// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// arrayCodec is the Codec used for bsoncore.Array values. The type carries no
// state; all behavior lives in its EncodeValue and DecodeValue methods.
type arrayCodec struct{}
// EncodeValue is the ValueEncoder for bsoncore.Array values. It writes the raw
// array bytes held by val to vw via copyArrayFromBytes and returns a
// ValueEncoderError when val is invalid or is not a bsoncore.Array.
func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if val.IsValid() && val.Type() == tCoreArray {
		return copyArrayFromBytes(vw, val.Interface().(bsoncore.Array))
	}
	return ValueEncoderError{
		Name:     "CoreArrayEncodeValue",
		Types:    []reflect.Type{tCoreArray},
		Received: val,
	}
}
// DecodeValue is the ValueDecoder for bsoncore.Array values. The destination
// must be a settable bsoncore.Array; otherwise a ValueDecoderError is returned.
// The destination slice is truncated to length zero and then refilled with the
// bytes that appendArrayBytes reads from vr.
func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	switch {
	case !val.CanSet(), val.Type() != tCoreArray:
		return ValueDecoderError{
			Name:     "CoreArrayDecodeValue",
			Types:    []reflect.Type{tCoreArray},
			Received: val,
		}
	case val.IsNil():
		// SetLen below needs a non-nil slice to operate on.
		val.Set(reflect.MakeSlice(val.Type(), 0, 0))
	}

	val.SetLen(0)
	buf := val.Interface().(bsoncore.Array)
	decoded, err := appendArrayBytes(buf, vr)
	// The result is stored even when err is non-nil, exactly as the reader returned it.
	val.Set(reflect.ValueOf(decoded))
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path"
"reflect"
"strconv"
"strings"
"testing"
"unicode"
"unicode/utf8"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/require"
)
// testCase is the Go representation of one corpus JSON file. It is populated by
// decoding the file with encoding/json (see the json tags for the expected keys).
// TestKey and Deprecated are pointers so that an absent key decodes to nil.
type testCase struct {
Description string `json:"description"`
BsonType string `json:"bson_type"`
TestKey *string `json:"test_key"`
Valid []validityTestCase `json:"valid"`
DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
ParseErrors []parseErrorTestCase `json:"parseErrors"`
Deprecated *bool `json:"deprecated"`
}
// validityTestCase is one entry of a testCase's "valid" list. CanonicalBson and the
// optional DegenerateBSON are hex-encoded strings (runTest decodes them with
// hex.DecodeString); the *ExtJSON fields hold extended JSON text. Pointer fields are
// nil when the corresponding key is absent from the file.
type validityTestCase struct {
Description string `json:"description"`
CanonicalBson string `json:"canonical_bson"`
CanonicalExtJSON string `json:"canonical_extjson"`
RelaxedExtJSON *string `json:"relaxed_extjson"`
DegenerateBSON *string `json:"degenerate_bson"`
DegenerateExtJSON *string `json:"degenerate_extjson"`
ConvertedBSON *string `json:"converted_bson"`
ConvertedExtJSON *string `json:"converted_extjson"`
Lossy *bool `json:"lossy"`
}
// decodeErrorTestCase is one entry of a testCase's "decodeErrors" list. Bson is a
// hex-encoded byte string that runTest passes to Unmarshal expecting a failure.
type decodeErrorTestCase struct {
Description string `json:"description"`
Bson string `json:"bson"`
}
// parseErrorTestCase is one entry of a testCase's "parseErrors" list. String is the
// input that runTest feeds to UnmarshalExtJSON expecting a failure.
type parseErrorTestCase struct {
Description string `json:"description"`
String string `json:"string"`
}
// dataDir is the directory, relative to this package, that holds the corpus JSON files.
const dataDir = "../testdata/bson-corpus/"
// findJSONFilesInDir returns the names (not full paths) of the regular entries in
// dir whose extension is ".json". Sub-directories are skipped. The returned slice
// is non-nil even when nothing matches; any error from reading dir is returned
// unchanged.
func findJSONFilesInDir(dir string) ([]string, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	names := []string{}
	for _, e := range entries {
		if !e.IsDir() && path.Ext(e.Name()) == ".json" {
			names = append(names, e.Name())
		}
	}
	return names, nil
}
// seedExtJSON converts the extended JSON string extJSON to bytes with jsonToBytes
// and adds the result to the fuzzer's corpus. A conversion failure aborts the
// fuzz setup via f.Fatalf.
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
	seed, convErr := jsonToBytes(extJSON, extJSONType, desc)
	if convErr != nil {
		f.Fatalf("failed to convert JSON to bytes: %v", convErr)
	}

	f.Add(seed)
}
// seedTestCase seeds the fuzzer's corpus with every extended JSON variant of every
// valid case in tcase: the canonical form always, followed by the relaxed,
// degenerate and converted forms when they are present.
func seedTestCase(f *testing.F, tcase *testCase) {
	for _, vtc := range tcase.Valid {
		seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)

		// Optional variants, in the order they are seeded.
		optional := []struct {
			extJSON *string
			kind    string
		}{
			{vtc.RelaxedExtJSON, "relaxed"},
			{vtc.DegenerateExtJSON, "degenerate"},
			{vtc.ConvertedExtJSON, "converted"},
		}
		for _, variant := range optional {
			if variant.extJSON == nil {
				continue
			}
			seedExtJSON(f, *variant.extJSON, variant.kind, vtc.Description)
		}
	}
}
// seedBSONCorpus decodes every JSON file found in dataDir into a testCase and seeds
// the fuzz corpus with the bytes derived from the "*_extjson" fields of each of its
// valid cases (see seedTestCase). Any failure to list, open or decode a file aborts
// the fuzz setup.
func seedBSONCorpus(f *testing.F) {
	fileNames, err := findJSONFilesInDir(dataDir)
	if err != nil {
		f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
	}

	for _, fileName := range fileNames {
		filePath := path.Join(dataDir, fileName)

		file, err := os.Open(filePath)
		if err != nil {
			f.Fatalf("failed to open file %q: %v", filePath, err)
		}

		var tcase testCase
		decodeErr := json.NewDecoder(file).Decode(&tcase)

		// Close the file before acting on the decode result so that the descriptor is
		// released on every path. The file was opened read-only, so a Close error
		// cannot lose data and is deliberately ignored.
		_ = file.Close()

		if decodeErr != nil {
			f.Fatal(decodeErr)
		}

		seedTestCase(f, &tcase)
	}
}
// needsEscapedUnicode reports whether bsonType is one of the type codes ("0x02",
// "0x0D", "0x0E", "0x0F") whose extended JSON is rewritten by unescapeUnicode.
func needsEscapedUnicode(bsonType string) bool {
	return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
}

// unescapeUnicode rewrites the `\uXXXX` escape sequences in s into the characters
// they denote. Strings for BSON types that do not need the rewrite (see
// needsEscapedUnicode) are returned unchanged.
//
// Decoded control characters (below U+0020) are written back as a lowercase
// `\uXXXX` escape. A backslash that does not start a `\u` escape, including a
// trailing backslash, is copied through verbatim. A malformed or truncated `\u`
// escape makes the function return the empty string.
func unescapeUnicode(s, bsonType string) string {
	if !needsEscapedUnicode(bsonType) {
		return s
	}

	var out strings.Builder
	out.Grow(len(s))

	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\\' && i+1 < len(s) && s[i+1] == 'u':
			// A complete escape is six bytes long. Clamp the slice so a truncated
			// escape reaches Unquote (which rejects it) instead of panicking here.
			end := i + 6
			if end > len(s) {
				end = len(s)
			}
			us := s[i:end]

			// Quote doubles the backslash; undoing that yields a quoted Go string
			// literal containing just the escape, which Unquote can then decode.
			u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
			if err != nil {
				return ""
			}
			for _, r := range u {
				if r < ' ' {
					fmt.Fprintf(&out, `\u%04x`, r)
				} else {
					out.WriteRune(r)
				}
			}
			// Skip the rest of the escape; the loop increment consumes the sixth byte.
			i += 5
		case c > unicode.MaxASCII:
			// Start of a multi-byte UTF-8 sequence: copy the whole rune at once.
			r, size := utf8.DecodeRuneInString(s[i:])
			out.WriteRune(r)
			i += size - 1
		default:
			out.WriteByte(c)
		}
	}

	return out.String()
}
// normalizeCanonicalDouble re-renders a canonical extended JSON document of the
// form {key: {"$numberDouble": "<float>"}}. The float string is parsed and then
// formatted again with formatDouble, and the document is rebuilt around the
// result. The test fails immediately if cEJ cannot be unmarshaled or the value
// cannot be parsed as a float.
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
	var wrapped map[string]map[string]string
	require.NoError(t, json.Unmarshal([]byte(cEJ), &wrapped))

	parsed, err := strconv.ParseFloat(wrapped[key]["$numberDouble"], 64)
	require.NoError(t, err)

	return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(parsed))
}
// normalizeRelaxedDouble re-renders a relaxed extended JSON document of the form
// {key: <number>} by formatting the number with formatDouble. When rEJ does not
// unmarshal into a map of plain numbers it is handed to normalizeCanonicalDouble
// instead.
func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
	var plain map[string]float64
	if err := json.Unmarshal([]byte(rEJ), &plain); err != nil {
		return normalizeCanonicalDouble(t, key, rEJ)
	}

	return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(plain[key]))
}
// bsonToNative unmarshals the BSON bytes b into a native D document, failing the
// test immediately if decoding returns an error. bType and testDesc are used only
// in the failure message.
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
	var decoded D
	require.NoErrorf(t, Unmarshal(b, &decoded), "%s: decoding %s BSON", testDesc, bType)
	return decoded
}
// nativeToBSON marshals the native document doc to BSON and requires the bytes to
// equal the expected canonical BSON cB. A marshaling error or any difference
// fails the test immediately; testDesc, bType and docSrcDesc only label the
// failure output.
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
	encoded, err := Marshal(doc)
	require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)

	if cmp.Diff(cB, encoded) == "" {
		return
	}
	t.Fatalf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
		testDesc, docSrcDesc, cB, encoded)
}
// jsonToNative unmarshals the extended JSON string ej into a native D document.
// Canonical parsing is used for every ejType except "relaxed". On failure the
// returned error wraps the cause together with testDesc and ejType.
func jsonToNative(ej, ejType, testDesc string) (D, error) {
	canonical := ejType != "relaxed"

	var doc D
	err := UnmarshalExtJSON([]byte(ej), canonical, &doc)
	if err == nil {
		return doc, nil
	}
	return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
}
// jsonToBytes converts the extended JSON string ej to BSON bytes by first decoding
// it into a native document (jsonToNative) and then marshaling that document.
// Errors from either step are returned; marshaling errors are wrapped with
// testDesc and ejType.
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
	doc, err := jsonToNative(ej, ejType, testDesc)
	if err != nil {
		return nil, err
	}

	encoded, err := Marshal(doc)
	if err == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
}
// nativeToJSON marshals the native document doc to extended JSON (canonical unless
// ejType is "relaxed") and requires the output to equal ej. A marshaling error or
// any difference fails the test immediately; the remaining string arguments only
// label the failure output.
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
	canonical := ejType != "relaxed"
	encoded, err := MarshalExtJSON(doc, canonical, true)
	require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)

	diff := cmp.Diff(ej, string(encoded))
	if diff == "" {
		return
	}
	t.Fatalf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
		testDesc, ejType, docSrcDesc, ejShortName, diff)
}
// runTest executes the corpus checks for one JSON file in dataDir. The file is read
// and unmarshaled into a testCase, and three groups of subtests run beneath a
// subtest named after the file:
//
//   - "valid": round-trip checks between canonical/degenerate BSON, canonical/
//     relaxed/degenerate extended JSON and the native D type.
//   - "decode error": BSON inputs that Unmarshal is expected to reject.
//   - "parse error": extended JSON inputs that are expected to be rejected.
func runTest(t *testing.T, file string) {
filepath := path.Join(dataDir, file)
content, err := os.ReadFile(filepath)
require.NoError(t, err)
t.Run(file, func(t *testing.T) {
var test testCase
require.NoError(t, json.Unmarshal(content, &test))
t.Run("valid", func(t *testing.T) {
for _, v := range test.Valid {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)
// get canonical extended JSON
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(v.CanonicalExtJSON)))
cEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
// Double (0x01) expectations are re-rendered through normalizeCanonicalDouble,
// which formats the value with formatDouble.
if test.BsonType == "0x01" {
cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
}
/*** canonical BSON round-trip tests ***/
doc := bsonToNative(t, cB, "canonical", v.Description)
// native_to_bson(bson_to_native(cB)) = cB
nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")
// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")
// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
if v.RelaxedExtJSON != nil {
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(*v.RelaxedExtJSON)))
rEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
if test.BsonType == "0x01" {
rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
}
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")
/*** relaxed extended JSON round-trip tests (if exists) ***/
doc, err = jsonToNative(rEJ, "relaxed", v.Description)
require.NoError(t, err)
// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)")
}
/*** canonical extended JSON round-trip tests ***/
doc, err = jsonToNative(cEJ, "canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")
// native_to_bson(json_to_native(cEJ)) = cb (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
}
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)
doc = bsonToNative(t, dB, "degenerate", v.Description)
// native_to_bson(bson_to_native(dB)) = cB
nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
}
/*** degenerate JSON round-trip tests (if exists) ***/
if v.DegenerateExtJSON != nil {
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(*v.DegenerateExtJSON)))
dEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
if test.BsonType == "0x01" {
dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
}
doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
}
}
})
}
})
t.Run("decode error", func(t *testing.T) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
require.NoError(t, err, d.Description)
var doc D
err = Unmarshal(b, &doc)
// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
// characters.
for _, elem := range doc {
value := reflect.ValueOf(elem.Value)
invalidString := (value.Kind() == reflect.String) && !utf8.ValidString(value.String())
dbPtr, ok := elem.Value.(DBPointer)
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
if invalidString || invalidDBPtr {
require.NoError(t, err, d.Description)
return
}
}
require.Errorf(t, err, "%s: expected decode error", d.Description)
})
}
})
t.Run("parse error", func(t *testing.T) {
for _, p := range test.ParseErrors {
t.Run(p.Description, func(t *testing.T) {
s := unescapeUnicode(p.String, test.BsonType)
// Decimal128 (0x13) parse-error inputs are bare values, so they are wrapped in a
// document before parsing.
if test.BsonType == "0x13" {
s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
}
switch test.BsonType {
case "0x00", "0x05", "0x13":
var doc D
err := UnmarshalExtJSON([]byte(s), true, &doc)
// Null bytes are validated when marshaling to BSON
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
require.Errorf(t, err, "%s: expected parse error", p.Description)
default:
// Only the types listed above are handled; any other type fails the test as a
// reminder to extend this switch.
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
}
})
}
})
})
}
// TestBSONCorpus runs runTest against every JSON file found in dataDir.
func TestBSONCorpus(t *testing.T) {
	corpusFiles, err := findJSONFilesInDir(dataDir)
	require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)

	for i := range corpusFiles {
		runTest(t, corpusFiles[i])
	}
}
// TestRelaxedUUIDValidation checks how UnmarshalExtJSON handles the "$uuid"
// shorthand. A well-formed UUID must decode and re-encode to the expected canonical
// "$binary" document; malformed UUIDs must fail with the exact expected message.
func TestRelaxedUUIDValidation(t *testing.T) {
	testCases := []struct {
		description       string
		canonicalExtJSON  string
		degenerateExtJSON string
		expectedErr       string
	}{
		{
			"valid uuid",
			"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
			"",
		},
		{
			"invalid uuid--no hyphens",
			"",
			"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
			"$uuid value does not follow RFC 4122 format regarding length and hyphens",
		},
		{
			"invalid uuid--trailing hyphens",
			"",
			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
			"$uuid value does not follow RFC 4122 format regarding length and hyphens",
		},
		{
			"invalid uuid--malformed hex",
			"",
			"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
			"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			// get canonical extended JSON (if provided)
			cEJ := ""
			if tc.canonicalExtJSON != "" {
				var compactCEJ bytes.Buffer
				require.NoError(t, json.Compact(&compactCEJ, []byte(tc.canonicalExtJSON)))
				cEJ = unescapeUnicode(compactCEJ.String(), "0x05")
			}

			// get degenerate extended JSON
			var compactDEJ bytes.Buffer
			require.NoError(t, json.Compact(&compactDEJ, []byte(tc.degenerateExtJSON)))
			dEJ := unescapeUnicode(compactDEJ.String(), "0x05")

			// convert dEJ to native doc
			var doc D
			err := UnmarshalExtJSON([]byte(dEJ), true, &doc)

			if tc.expectedErr != "" {
				// Stop here when no error was returned: calling err.Error() on a nil
				// error would panic instead of reporting a readable test failure.
				require.Errorf(t, err, "expected error %q, got nil", tc.expectedErr)
				assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
				return
			}

			// A decode failure makes the round-trip comparison below meaningless, so
			// abort rather than report a second, misleading mismatch.
			require.NoErrorf(t, err, "expected no error, got error: %v", err)

			// Marshal doc into extended JSON and compare with cEJ
			nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
		})
	}
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"strings"
)
var (
// emptyValue is the zero reflect.Value. Decoders in this package return it
// alongside a non-nil error.
emptyValue = reflect.Value{}
)
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
//
// Name identifies the encoder, Types and Kinds list what the encoder accepts, and
// Received is the value it was actually given.
type ValueEncoderError struct {
	Name     string
	Types    []reflect.Type
	Kinds    []reflect.Kind
	Received reflect.Value
}

// Error implements the error interface. The message lists the accepted types
// followed by the accepted kinds (reflect.Map is rendered as "map[string]*") and
// then the type of the received value, or its kind when the value is invalid.
func (vee ValueEncoderError) Error() string {
	accepted := make([]string, 0, len(vee.Types)+len(vee.Kinds))
	for i := range vee.Types {
		accepted = append(accepted, vee.Types[i].String())
	}
	for _, kind := range vee.Kinds {
		label := kind.String()
		if kind == reflect.Map {
			label = "map[string]*"
		}
		accepted = append(accepted, label)
	}

	var got string
	if vee.Received.IsValid() {
		got = vee.Received.Type().String()
	} else {
		got = vee.Received.Kind().String()
	}

	return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(accepted, ", "), got)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
//
// Name identifies the decoder, Types and Kinds list what the decoder accepts, and
// Received is the value it was actually given.
type ValueDecoderError struct {
	Name     string
	Types    []reflect.Type
	Kinds    []reflect.Kind
	Received reflect.Value
}

// Error implements the error interface. The message lists the accepted types
// followed by the accepted kinds (reflect.Map is rendered as "map[string]*") and
// then the type of the received value, or its kind when the value is invalid.
func (vde ValueDecoderError) Error() string {
	accepted := make([]string, 0, len(vde.Types)+len(vde.Kinds))
	for i := range vde.Types {
		accepted = append(accepted, vde.Types[i].String())
	}
	for _, kind := range vde.Kinds {
		label := kind.String()
		if kind == reflect.Map {
			label = "map[string]*"
		}
		accepted = append(accepted, label)
	}

	var got string
	if vde.Received.IsValid() {
		got = vde.Received.Type().String()
	} else {
		got = vde.Received.Kind().String()
	}

	return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(accepted, ", "), got)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value. It embeds a *Registry and carries a set of boolean encoding options.
type EncodeContext struct {
*Registry
// minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits)
// that can represent the integer value.
minSize bool
// NOTE(review): the options below are undocumented in this file; their meaning is
// presumably what their names suggest — confirm against the code that reads them.
errorOnInlineDuplicates bool
stringifyMapKeysWithFmt bool
nilMapAsEmpty bool
nilSliceAsEmpty bool
// nilByteSliceAsEmpty, when true, makes byteSliceCodec.EncodeValue write a nil []byte
// as a BSON binary value instead of BSON null.
nilByteSliceAsEmpty bool
omitZeroStruct bool
omitEmpty bool
useJSONStructTags bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value. It embeds a *Registry and carries a set of decoding options.
type DecodeContext struct {
*Registry
// truncate, if true, instructs decoders to truncate the fractional part of BSON "double"
// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
// BSON "decimal128" values.
truncate bool
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If
// defaultDocumentType is set to a type that a BSON document cannot be unmarshaled into (e.g. "string"),
// unmarshalling will result in an error.
defaultDocumentType reflect.Type
binaryAsSlice bool
// NOTE(review): the beginning of the original comment for objectIDAsHexString appears to be
// missing. The surviving text states only that a false value results in a decoding error;
// presumably the flag controls decoding ObjectID values into hex strings — confirm.
objectIDAsHexString bool
useJSONStructTags bool
useLocalTimeZone bool
zeroMaps bool
zeroStructs bool
}
// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to look up further ValueEncoders and
// to provide configuration information.
type ValueEncoder interface {
EncodeValue(EncodeContext, ValueWriter, reflect.Value) error
}
// ValueEncoderFunc adapts an ordinary function with the ValueEncoder signature so
// that it can be used wherever a ValueEncoder is required.
type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error

// EncodeValue implements the ValueEncoder interface by calling fn itself.
func (fn ValueEncoderFunc) EncodeValue(ctx EncodeContext, w ValueWriter, v reflect.Value) error {
	return fn(ctx, w, v)
}
// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
// ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the
// EncodeContext.
type ValueDecoder interface {
DecodeValue(DecodeContext, ValueReader, reflect.Value) error
}
// ValueDecoderFunc adapts an ordinary function with the ValueDecoder signature so
// that it can be used wherever a ValueDecoder is required.
type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error

// DecodeValue implements the ValueDecoder interface by calling fn itself.
func (fn ValueDecoderFunc) DecodeValue(ctx DecodeContext, r ValueReader, v reflect.Value) error {
	return fn(ctx, r, v)
}
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
// Unlike ValueDecoder, which fills in a caller-supplied value, decodeType returns the decoded
// value to the caller.
type typeDecoder interface {
decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
}
// typeDecoderFunc adapts an ordinary function with the typeDecoder signature so
// that it can be used wherever a typeDecoder is required.
type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)

// decodeType implements the typeDecoder interface by calling fn itself.
func (fn typeDecoderFunc) decodeType(ctx DecodeContext, r ValueReader, typ reflect.Type) (reflect.Value, error) {
	return fn(ctx, r, typ)
}
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
// It does so by embedding both function adapter types, which promotes their DecodeValue and
// decodeType methods.
type decodeAdapter struct {
ValueDecoderFunc
typeDecoderFunc
}
// Compile-time checks that decodeAdapter satisfies both decoder interfaces.
var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}
// decodeTypeOrValueWithInfo decodes the next value from vr as type t using vd.
// When vd also implements typeDecoder, its decodeType method is used and the
// result is converted to t if the types differ; otherwise a new zero value of t is
// allocated and filled in through vd.DecodeValue.
func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	td, ok := vd.(typeDecoder)
	if !ok {
		target := reflect.New(t).Elem()
		err := vd.DecodeValue(dc, vr, target)
		return target, err
	}

	decoded, err := td.decodeType(dc, vr, t)
	if err != nil || decoded.Type() == t {
		return decoded, err
	}

	// This conversion step is necessary for slices and maps. If a user declares variables like:
	//
	//  type myBool bool
	//  var m map[string]myBool
	//
	// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not
	// present because we'll try to assign a value of type bool to one of type myBool.
	return decoded.Convert(t), nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
)
// byteSliceCodec is the Codec used for []byte values.
type byteSliceCodec struct {
// encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null. EncodeValue treats the EncodeContext's nilByteSliceAsEmpty option the
// same way; either one being true is sufficient.
encodeNilAsEmpty bool
}
// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be
// used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &byteSliceCodec{}
// EncodeValue is the ValueEncoder for []byte. A nil slice is written as BSON null
// unless either the codec's encodeNilAsEmpty or the context's nilByteSliceAsEmpty
// option is set; every other slice is written as a BSON binary value. A
// ValueEncoderError is returned when val is invalid or is not a []byte.
func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tByteSlice {
		return ValueEncoderError{
			Name:     "ByteSliceEncodeValue",
			Types:    []reflect.Type{tByteSlice},
			Received: val,
		}
	}

	nilAsEmpty := bsc.encodeNilAsEmpty || ec.nilByteSliceAsEmpty
	if val.IsNil() && !nilAsEmpty {
		return vw.WriteNull()
	}

	payload := val.Interface().([]byte)
	return vw.WriteBinary(payload)
}
// decodeType reads the next value from vr and returns it as a []byte. BSON string
// and symbol values yield their bytes; binary values are accepted only with the
// generic or old-binary subtype; null and undefined yield a nil slice. Any other
// BSON type, a read failure, or a target type other than []byte produces an error
// together with emptyValue.
func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tByteSlice {
		return emptyValue, ValueDecoderError{
			Name:     "ByteSliceDecodeValue",
			Types:    []reflect.Type{tByteSlice},
			Received: reflect.Zero(t),
		}
	}

	switch bsonType := vr.Type(); bsonType {
	case TypeString:
		s, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		return reflect.ValueOf([]byte(s)), nil
	case TypeSymbol:
		s, err := vr.ReadSymbol()
		if err != nil {
			return emptyValue, err
		}
		return reflect.ValueOf([]byte(s)), nil
	case TypeBinary:
		payload, subtype, err := vr.ReadBinary()
		if err != nil {
			return emptyValue, err
		}
		if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
		}
		return reflect.ValueOf(payload), nil
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
		return reflect.ValueOf([]byte(nil)), nil
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
		return reflect.ValueOf([]byte(nil)), nil
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a []byte", bsonType)
	}
}
// DecodeValue is the ValueDecoder for []byte. The destination must be a settable
// []byte; the value itself is produced by decodeType and stored into val only when
// decoding succeeds.
func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tByteSlice {
		return ValueDecoderError{
			Name:     "ByteSliceDecodeValue",
			Types:    []reflect.Type{tByteSlice},
			Received: val,
		}
	}

	decoded, err := bsc.decodeType(dc, vr, tByteSlice)
	if err == nil {
		val.Set(decoded)
	}
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
"sync/atomic"
)
// Runtime check that the kind encoder and decoder caches can store any valid
// reflect.Kind constant. The cache length, converted to a reflect.Kind, must have
// no symbolic name (it must render as "kind27"); otherwise the process panics at
// start-up.
func init() {
	capacity := len(kindEncoderCache{}.entries)
	if reflect.Kind(capacity).String() == "kind27" {
		return
	}
	panic("The capacity of kindEncoderCache is too small.\n" +
		"This is due to a new type being added to reflect.Kind.")
}
// statically assert array size: indexing with the constant reflect.UnsafePointer
// fails to compile if either entries array is too short to hold that kind.
var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer]
var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer]
// typeEncoderCache maps a reflect.Type to its ValueEncoder on top of a sync.Map.
type typeEncoderCache struct {
	cache sync.Map // map[reflect.Type]ValueEncoder
}

// Store records enc as the encoder for rt, replacing any previous entry.
func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) {
	c.cache.Store(rt, enc)
}

// Load returns the encoder cached for rt. The boolean is false when there is no
// entry or the stored value is nil.
func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) {
	cached, _ := c.cache.Load(rt)
	if cached == nil {
		return nil, false
	}
	return cached.(ValueEncoder), true
}

// LoadOrStore returns the encoder already cached for rt when there is one;
// otherwise it stores enc and returns it.
func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder {
	existing, loaded := c.cache.LoadOrStore(rt, enc)
	if !loaded {
		return enc
	}
	return existing.(ValueEncoder)
}

// Clone returns a new cache holding every entry of c whose key and value are both
// non-nil.
func (c *typeEncoderCache) Clone() *typeEncoderCache {
	dup := &typeEncoderCache{}
	copyEntry := func(k, v interface{}) bool {
		if k == nil || v == nil {
			return true
		}
		dup.cache.Store(k, v)
		return true
	}
	c.cache.Range(copyEntry)
	return dup
}
// typeDecoderCache maps a reflect.Type to its ValueDecoder on top of a sync.Map.
type typeDecoderCache struct {
	cache sync.Map // map[reflect.Type]ValueDecoder
}

// Store records dec as the decoder for rt, replacing any previous entry.
func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) {
	c.cache.Store(rt, dec)
}

// Load returns the decoder cached for rt. The boolean is false when there is no
// entry or the stored value is nil.
func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) {
	cached, _ := c.cache.Load(rt)
	if cached == nil {
		return nil, false
	}
	return cached.(ValueDecoder), true
}

// LoadOrStore returns the decoder already cached for rt when there is one;
// otherwise it stores dec and returns it.
func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder {
	existing, loaded := c.cache.LoadOrStore(rt, dec)
	if !loaded {
		return dec
	}
	return existing.(ValueDecoder)
}

// Clone returns a new cache holding every entry of c whose key and value are both
// non-nil.
func (c *typeDecoderCache) Clone() *typeDecoderCache {
	dup := &typeDecoderCache{}
	copyEntry := func(k, v interface{}) bool {
		if k == nil || v == nil {
			return true
		}
		dup.cache.Store(k, v)
		return true
	}
	c.cache.Range(copyEntry)
	return dup
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueEncoder interface).
type kindEncoderCacheEntry struct {
	enc ValueEncoder
}

// kindEncoderCache maps a reflect.Kind to its ValueEncoder using a fixed-size
// array of atomic values indexed by kind.
type kindEncoderCache struct {
	entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry
}

// Store records enc as the encoder for kind rt. Nil encoders and kinds outside the
// array are ignored.
func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) {
	if enc != nil && rt < reflect.Kind(len(c.entries)) {
		c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc})
	}
}

// Load returns the encoder cached for kind rt. The boolean is false when the kind
// is out of range, nothing has been stored for it, or the stored encoder is nil.
func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) {
	if rt < reflect.Kind(len(c.entries)) {
		if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok {
			return ent.enc, ent.enc != nil
		}
	}
	return nil, false
}

// Clone returns a new cache containing every entry currently stored in c.
func (c *kindEncoderCache) Clone() *kindEncoderCache {
	cc := new(kindEncoderCache)
	// Iterate by index: ranging over the array by value would copy each
	// atomic.Value with a plain, non-atomic read, which is not allowed once the
	// value is in use and would race with a concurrent Store.
	for i := range c.entries {
		if val := c.entries[i].Load(); val != nil {
			cc.entries[i].Store(val)
		}
	}
	return cc
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueDecoder interface).
type kindDecoderCacheEntry struct {
	dec ValueDecoder
}

// kindDecoderCache maps a reflect.Kind to its ValueDecoder using a fixed-size
// array of atomic values indexed by kind.
type kindDecoderCache struct {
	entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry
}

// Store records dec as the decoder for kind rt. Kinds outside the array are
// ignored. Unlike kindEncoderCache.Store, a nil decoder is stored (wrapped in an
// entry); Load then reports it as absent.
func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) {
	if rt < reflect.Kind(len(c.entries)) {
		c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec})
	}
}

// Load returns the decoder cached for kind rt. The boolean is false when the kind
// is out of range, nothing has been stored for it, or the stored decoder is nil.
func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) {
	if rt < reflect.Kind(len(c.entries)) {
		if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok {
			return ent.dec, ent.dec != nil
		}
	}
	return nil, false
}

// Clone returns a new cache containing every entry currently stored in c.
func (c *kindDecoderCache) Clone() *kindDecoderCache {
	cc := new(kindDecoderCache)
	// Iterate by index: ranging over the array by value would copy each
	// atomic.Value with a plain, non-atomic read, which is not allowed once the
	// value is in use and would race with a concurrent Store.
	for i := range c.entries {
		if val := c.entries[i].Load(); val != nil {
			cc.entries[i].Store(val)
		}
	}
	return cc
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
// EncodeValue uses canAddrEnc when the value is addressable and elseEnc (if set) otherwise.
type condAddrEncoder struct {
canAddrEnc ValueEncoder
elseEnc ValueEncoder
}
// Compile-time check that *condAddrEncoder implements ValueEncoder.
var _ ValueEncoder = &condAddrEncoder{}
// newCondAddrEncoder returns a condAddrEncoder that uses canAddrEnc for
// addressable values and elseEnc for all others.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
	return &condAddrEncoder{
		canAddrEnc: canAddrEnc,
		elseEnc:    elseEnc,
	}
}
// EncodeValue is the ValueEncoderFunc for a value that may be addressable. It
// delegates to canAddrEnc when val is addressable, to elseEnc when one is
// configured, and returns errNoEncoder otherwise.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	switch {
	case val.CanAddr():
		return cae.canAddrEnc.EncodeValue(ec, vw, val)
	case cae.elseEnc != nil:
		return cae.elseEnc.EncodeValue(ec, vw, val)
	default:
		return errNoEncoder{Type: val.Type()}
	}
}
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
// DecodeValue uses canAddrDec when the value is addressable and elseDec (if set) otherwise.
type condAddrDecoder struct {
canAddrDec ValueDecoder
elseDec ValueDecoder
}
// Compile-time check that *condAddrDecoder implements ValueDecoder.
var _ ValueDecoder = &condAddrDecoder{}
// newCondAddrDecoder returns a condAddrDecoder that uses canAddrDec for
// addressable values and elseDec for all others.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
	return &condAddrDecoder{
		canAddrDec: canAddrDec,
		elseDec:    elseDec,
	}
}
// DecodeValue is the ValueDecoderFunc for a value that may be addressable. It
// delegates to canAddrDec when val is addressable, to elseDec when one is
// configured, and returns errNoDecoder otherwise.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	switch {
	case val.CanAddr():
		return cad.canAddrDec.DecodeValue(dc, vr, val)
	case cad.elseDec != nil:
		return cad.elseDec.DecodeValue(dc, vr, val)
	default:
		return errNoDecoder{Type: val.Type()}
	}
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// copyDocument copies a single document from src to dst. It opens the document on
// the reader first, then on the writer, and hands both to copyDocumentCore.
func copyDocument(dst ValueWriter, src ValueReader) error {
	reader, err := src.ReadDocument()
	if err != nil {
		return err
	}

	writer, err := dst.WriteDocument()
	if err != nil {
		return err
	}

	return copyDocumentCore(writer, reader)
}
// copyArrayFromBytes writes the BSON array encoded in src to dst: it opens an
// array on the writer, copies the elements, and closes the array. The first error
// encountered is returned.
func copyArrayFromBytes(dst ValueWriter, src []byte) error {
	arrayWriter, err := dst.WriteArray()
	if err != nil {
		return err
	}

	if err := copyBytesToArrayWriter(arrayWriter, src); err != nil {
		return err
	}

	return arrayWriter.WriteArrayEnd()
}
// copyDocumentFromBytes opens a document on dst, copies every element of the
// BSON document encoded in src into it, and then closes the document.
func copyDocumentFromBytes(dst ValueWriter, src []byte) error {
	docWriter, err := dst.WriteDocument()
	if err != nil {
		return err
	}

	if err := copyBytesToDocumentWriter(docWriter, src); err != nil {
		return err
	}

	return docWriter.WriteDocumentEnd()
}
// writeElementFn returns the ValueWriter that the next element should be
// written to. The key argument is the element's key; array writers ignore it.
type writeElementFn func(key string) (ValueWriter, error)
// copyBytesToArrayWriter copies each value of the BSON array encoded in src
// into a new array element on dst. Element keys found in src are discarded.
func copyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
	return copyBytesToValueWriter(src, func(string) (ValueWriter, error) {
		return dst.WriteArrayElement()
	})
}
// copyBytesToDocumentWriter copies each element of the BSON document encoded
// in src into a new element on dst, keeping the element's key.
func copyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
	return copyBytesToValueWriter(src, func(key string) (ValueWriter, error) {
		return dst.WriteDocumentElement(key)
	})
}
// copyBytesToValueWriter walks the BSON document or array encoded in src and
// copies each element's value to the ValueWriter that wef returns for the
// element's key.
//
// src must start with a length prefix. The prefix counts itself, so any value
// below 4 is rejected up front instead of producing a negative slice bound.
// The prefix must also not exceed len(src). Bytes in src beyond the declared
// length are ignored.
//
// It returns io.EOF if the data runs out before a terminating zero type byte
// is found, and a descriptive error for any other malformed input. Errors from
// wef and from writing a value are returned unchanged.
func copyBytesToValueWriter(src []byte, wef writeElementFn) error {
	// TODO(skriptble): Create errors types here. Anything that is a tag should be a property.
	length, rem, ok := bsoncore.ReadLength(src)
	if !ok {
		return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
	}
	if length < 4 {
		// The slice expression below would panic on length-4 < 0.
		return fmt.Errorf("invalid length read from src, must be at least 4. length=%d", length)
	}
	if len(src) < int(length) {
		return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", length, len(src))
	}
	// Restrict parsing to the bytes the length prefix claims (minus the prefix itself).
	rem = rem[:length-4]

	var t bsoncore.Type
	var key string
	var val bsoncore.Value
	for {
		t, rem, ok = bsoncore.ReadType(rem)
		if !ok {
			return io.EOF
		}
		if t == bsoncore.Type(0) {
			// A zero type byte terminates the document; nothing may follow it.
			if len(rem) != 0 {
				return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
			}
			break
		}

		key, rem, ok = bsoncore.ReadKey(rem)
		if !ok {
			return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
		}

		// write as either array element or document element using writeElementFn
		vw, err := wef(key)
		if err != nil {
			return err
		}

		val, rem, ok = bsoncore.ReadValue(rem, t)
		if !ok {
			return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
		}
		err = copyValueFromBytes(vw, Type(t), val.Data)
		if err != nil {
			return err
		}
	}
	return nil
}
// copyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes. It is appendDocumentBytes called with a nil
// destination.
func copyDocumentToBytes(src ValueReader) ([]byte, error) {
return appendDocumentBytes(nil, src)
}
// appendDocumentBytes reads one document from src and appends its bytes to
// dst, returning the extended slice.
//
// When src implements bytesReader the bytes are taken directly from it.
// Otherwise a pooled valueWriter is reset onto dst and the document is copied
// into it element by element.
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
	if direct, ok := src.(bytesReader); ok {
		_, out, err := direct.readValueBytes(dst)
		return out, err
	}

	writer := vwPool.Get().(*valueWriter)
	defer putValueWriter(writer)

	writer.reset(dst)
	err := copyDocument(writer, src)
	return writer.buf, err
}
// appendArrayBytes reads one array from src and appends its bytes to dst,
// returning the extended slice.
//
// When src implements bytesReader the bytes are taken directly from it.
// Otherwise a pooled valueWriter is reset onto dst and the array is copied
// into it value by value.
func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
	if direct, ok := src.(bytesReader); ok {
		_, out, err := direct.readValueBytes(dst)
		return out, err
	}

	writer := vwPool.Get().(*valueWriter)
	defer putValueWriter(writer)

	writer.reset(dst)
	err := copyArray(writer, src)
	return writer.buf, err
}
// copyValueFromBytes writes the value represented by t and src to dst.
//
// If dst implements bytesWriter the raw bytes are handed to it directly.
// Otherwise src is wrapped in a document reader positioned on an element of
// type t and the value is copied through copyValue.
func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error {
	if raw, ok := dst.(bytesWriter); ok {
		return raw.writeValueBytes(t, src)
	}

	reader := newDocumentReader(bytes.NewReader(src))
	reader.pushElement(t)

	return copyValue(dst, reader)
}
// copyValueToBytes reads a single value from src and returns its BSON type
// together with the value's bytes.
//
// When src implements bytesReader the result comes straight from it.
// Otherwise the value is written as an element into a pooled valueWriter; the
// first byte of that buffer is returned as the type and everything from the
// third byte onward as the value.
func copyValueToBytes(src ValueReader) (Type, []byte, error) {
	if direct, ok := src.(bytesReader); ok {
		return direct.readValueBytes(nil)
	}

	writer := vwPool.Get().(*valueWriter)
	defer putValueWriter(writer)

	writer.reset(nil)
	writer.push(mElement)
	if err := copyValue(writer, src); err != nil {
		return 0, nil, err
	}

	return Type(writer.buf[0]), writer.buf[2:], nil
}
// copyValue copies the single value that src is positioned on to dst.
//
// The value is read with the Read* method matching src.Type() and written with
// the corresponding Write* method. Embedded documents, arrays, and the scope
// of a code-with-scope value are copied recursively. A read error is returned
// before anything is written for that value. An unrecognized type yields an
// error naming the type.
func copyValue(dst ValueWriter, src ValueReader) error {
	switch src.Type() {
	case TypeDouble:
		f64, err := src.ReadDouble()
		if err != nil {
			return err
		}
		return dst.WriteDouble(f64)

	case TypeString:
		str, err := src.ReadString()
		if err != nil {
			return err
		}
		return dst.WriteString(str)

	case TypeEmbeddedDocument:
		return copyDocument(dst, src)

	case TypeArray:
		return copyArray(dst, src)

	case TypeBinary:
		data, subtype, err := src.ReadBinary()
		if err != nil {
			return err
		}
		return dst.WriteBinaryWithSubtype(data, subtype)

	case TypeUndefined:
		if err := src.ReadUndefined(); err != nil {
			return err
		}
		return dst.WriteUndefined()

	case TypeObjectID:
		oid, err := src.ReadObjectID()
		if err != nil {
			return err
		}
		return dst.WriteObjectID(oid)

	case TypeBoolean:
		b, err := src.ReadBoolean()
		if err != nil {
			return err
		}
		return dst.WriteBoolean(b)

	case TypeDateTime:
		dt, err := src.ReadDateTime()
		if err != nil {
			return err
		}
		return dst.WriteDateTime(dt)

	case TypeNull:
		if err := src.ReadNull(); err != nil {
			return err
		}
		return dst.WriteNull()

	case TypeRegex:
		pattern, options, err := src.ReadRegex()
		if err != nil {
			return err
		}
		return dst.WriteRegex(pattern, options)

	case TypeDBPointer:
		ns, pointer, err := src.ReadDBPointer()
		if err != nil {
			return err
		}
		return dst.WriteDBPointer(ns, pointer)

	case TypeJavaScript:
		js, err := src.ReadJavascript()
		if err != nil {
			return err
		}
		return dst.WriteJavascript(js)

	case TypeSymbol:
		symbol, err := src.ReadSymbol()
		if err != nil {
			return err
		}
		return dst.WriteSymbol(symbol)

	case TypeCodeWithScope:
		code, srcScope, err := src.ReadCodeWithScope()
		if err != nil {
			return err
		}
		dstScope, err := dst.WriteCodeWithScope(code)
		if err != nil {
			return err
		}
		// The scope is itself a document and is copied element by element.
		return copyDocumentCore(dstScope, srcScope)

	case TypeInt32:
		i32, err := src.ReadInt32()
		if err != nil {
			return err
		}
		return dst.WriteInt32(i32)

	case TypeTimestamp:
		t, i, err := src.ReadTimestamp()
		if err != nil {
			return err
		}
		return dst.WriteTimestamp(t, i)

	case TypeInt64:
		i64, err := src.ReadInt64()
		if err != nil {
			return err
		}
		return dst.WriteInt64(i64)

	case TypeDecimal128:
		d128, err := src.ReadDecimal128()
		if err != nil {
			return err
		}
		return dst.WriteDecimal128(d128)

	case TypeMinKey:
		if err := src.ReadMinKey(); err != nil {
			return err
		}
		return dst.WriteMinKey()

	case TypeMaxKey:
		if err := src.ReadMaxKey(); err != nil {
			return err
		}
		return dst.WriteMaxKey()

	default:
		return fmt.Errorf("cannot copy unknown BSON type %s", src.Type())
	}
}
// copyArray reads an array from src and writes it to dst, copying one value at
// a time until the source reports ErrEOA, then closes the destination array.
func copyArray(dst ValueWriter, src ValueReader) error {
	arrReader, err := src.ReadArray()
	if err != nil {
		return err
	}

	arrWriter, err := dst.WriteArray()
	if err != nil {
		return err
	}

	for {
		elemReader, err := arrReader.ReadValue()
		switch {
		case errors.Is(err, ErrEOA):
			// End of the source array: finish the destination array.
			return arrWriter.WriteArrayEnd()
		case err != nil:
			return err
		}

		elemWriter, err := arrWriter.WriteArrayElement()
		if err != nil {
			return err
		}

		if err := copyValue(elemWriter, elemReader); err != nil {
			return err
		}
	}
}
// copyDocumentCore copies elements from dr to dw, keeping each key, until the
// reader reports ErrEOD, and then closes the destination document.
func copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
	for {
		key, elemReader, err := dr.ReadElement()
		switch {
		case errors.Is(err, ErrEOD):
			// End of the source document: finish the destination document.
			return dw.WriteDocumentEnd()
		case err != nil:
			return err
		}

		elemWriter, err := dw.WriteDocumentElement(key)
		if err != nil {
			return err
		}

		if err := copyValue(elemWriter, elemReader); err != nil {
			return err
		}
	}
}
// bytesReader is the interface used to read BSON bytes from a valueReader.
//
// The bytes of the value will be appended to dst. The copy helpers in this
// file type-assert a ValueReader to bytesReader and, when that succeeds, use
// readValueBytes instead of copying the value piece by piece.
type bytesReader interface {
// readValueBytes returns the type of the value read, dst with the value's
// bytes appended, and any error.
readValueBytes(dst []byte) (Type, []byte, error)
}
// bytesWriter is the interface used to write BSON bytes to a valueWriter.
// copyValueFromBytes type-asserts its destination to bytesWriter and, when
// that succeeds, hands it the raw value bytes directly.
type bytesWriter interface {
// writeValueBytes writes the value of type t whose encoded bytes are b.
writeValueBytes(t Type, b []byte) error
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"encoding/json"
"errors"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
"go.mongodb.org/mongo-driver/v2/internal/decimal128"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
// MaxDecimal128Exp is the largest exponent accepted when building a Decimal128.
MaxDecimal128Exp = 6111
// MinDecimal128Exp is the smallest exponent accepted when building a
// Decimal128. It also serves as the bias: the stored exponent bits hold
// exp - MinDecimal128Exp.
MinDecimal128Exp = -6176
)
// These errors are returned when an invalid value is parsed as a big.Int.
// Decimal128.BigInt returns them for values that have no integer significand.
var (
// ErrParseNaN is returned for a NaN value.
ErrParseNaN = errors.New("cannot parse NaN as a *big.Int")
// ErrParseInf is returned for positive infinity.
ErrParseInf = errors.New("cannot parse Infinity as a *big.Int")
// ErrParseNegInf is returned for negative infinity.
ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
)
// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
// h and l are the high and low 64 bits of the 128-bit value. The sign is
// bit 63 of h; the exponent and the upper significand bits also live in h.
h, l uint64
}
// NewDecimal128 creates a Decimal128 from the provided high (h) and low (l)
// 64-bit halves. The halves are stored as given, without validation.
func NewDecimal128(h, l uint64) Decimal128 {
	var d Decimal128
	d.h, d.l = h, l
	return d
}
// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
// contains the high 64 bits of the value and the second contains the low 64 bits.
func (d Decimal128) GetBytes() (uint64, uint64) {
return d.h, d.l
}
// String returns a string representation of the decimal value. Formatting is
// delegated to the internal decimal128 package.
func (d Decimal128) String() string {
return decimal128.String(d.h, d.l)
}
// BigInt returns the significand as a *big.Int together with the decimal
// exponent, so that the value equals bi * 10^exp.
//
// NaN and the two infinities have no integer significand; for those a nil
// *big.Int is returned along with ErrParseNaN, ErrParseInf, or ErrParseNegInf.
func (d Decimal128) BigInt() (*big.Int, int, error) {
	high, low := d.GetBytes()
	negative := high>>63&1 != 0

	// The five bits below the sign bit mark NaN (0x1F) and Infinity (0x1E).
	switch high >> 58 & 0x1F {
	case 0x1F:
		return nil, 0, ErrParseNaN
	case 0x1E:
		if negative {
			return nil, 0, ErrParseNegInf
		}
		return nil, 0, ErrParseInf
	}

	var biased uint64
	if high>>61&3 == 3 {
		// Bits: 1*sign 2*ignored 14*exponent 111*significand.
		// Implicit 0b100 prefix in significand.
		biased = high >> 47 & (1<<14 - 1)
		// Spec says all of these values are out of range.
		high, low = 0, 0
	} else {
		// Bits: 1*sign 14*exponent 113*significand
		biased = high >> 49 & (1<<14 - 1)
		high &= 1<<49 - 1
	}
	exp := int(biased) + MinDecimal128Exp

	// Would be handled by the logic below, but that's trivial and common.
	if high == 0 && low == 0 && exp == 0 {
		return new(big.Int), 0, nil
	}

	// big.Word is the platform word size, so a 32-bit host needs four words
	// for the 128-bit significand and a 64-bit host needs two.
	const host32bit = ^uint(0)>>32 == 0
	var words []big.Word
	if host32bit {
		words = []big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)}
	} else {
		words = []big.Word{big.Word(low), big.Word(high)}
	}

	bi := new(big.Int).SetBits(words)
	if negative {
		bi.Neg(bi)
	}
	return bi, exp, nil
}
// IsNaN reports whether d is NaN, i.e. whether the five bits directly below
// the sign bit are all set.
func (d Decimal128) IsNaN() bool {
	const nanPattern = 0x1F
	return (d.h>>58)&nanPattern == nanPattern
}
// IsInf reports whether d is an infinity and, if so, which one:
//
//	+1 if d == Infinity
//	 0 if d is not an infinity
//	-1 if d == -Infinity
func (d Decimal128) IsInf() int {
	const infPattern = 0x1E
	switch {
	case (d.h>>58)&0x1F != infPattern:
		return 0
	case d.h>>63 == 0:
		// Sign bit clear: positive infinity.
		return 1
	default:
		return -1
	}
}
// IsZero reports whether d is the empty Decimal128, meaning both of its 64-bit
// halves are zero.
func (d Decimal128) IsZero() bool {
	// The OR of the two halves is zero exactly when both halves are zero.
	return d.h|d.l == 0
}
// MarshalJSON encodes d as a JSON string holding the output of d.String().
func (d Decimal128) MarshalJSON() ([]byte, error) {
	text := d.String()
	return json.Marshal(text)
}
// UnmarshalJSON sets d from a JSON string, an extended JSON document of the
// form {"$numberDecimal": "<string>"}, or the literal "null".
//
// For "null" d is left unchanged and nil is returned. For the other two forms
// the string is passed to ParseDecimal128 and d is assigned its result, even
// when parsing fails. Any other JSON value produces an error.
func (d *Decimal128) UnmarshalJSON(b []byte) error {
	// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
	// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
	// enter the UnmarshalJSON hook.
	if string(b) == "null" {
		return nil
	}

	var decoded interface{}
	if err := json.Unmarshal(b, &decoded); err != nil {
		return err
	}

	var text string
	switch v := decoded.(type) {
	case string:
		text = v
	case map[string]interface{}:
		// Extended JSON
		field, found := v["$numberDecimal"]
		if !found {
			return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
		}
		s, isString := field.(string)
		if !isString {
			return errors.New("not an extended JSON Decimal128: expected decimal to be string")
		}
		text = s
	default:
		return errors.New("not an extended JSON Decimal128: expected document")
	}

	parsed, err := ParseDecimal128(text)
	*d = parsed
	return err
}
// dNaN is the Decimal128 NaN value: the five bits below the sign bit are all set.
var dNaN = Decimal128{0x1F << 58, 0}
// dPosInf is positive infinity: pattern 0x1E below a clear sign bit.
var dPosInf = Decimal128{0x1E << 58, 0}
// dNegInf is negative infinity: the same pattern as dPosInf with the sign bit set.
var dNegInf = Decimal128{0x3E << 58, 0}
// dErr returns the result pair used for an unparseable input: dNaN together
// with an error that quotes s.
func dErr(s string) (Decimal128, error) {
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}
// normalNumber matches a number in plain or scientific notation, for example
// -10.15e-18. Its capture groups are, in order: the optionally signed integer
// digits ("int"), the digits after the decimal point ("dec"), and the
// optionally signed exponent ("exp").
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)
// ParseDecimal128 parses s into a Decimal128.
//
// Accepted inputs are numbers in plain or scientific notation and, compared
// case-insensitively with an optional leading sign, "nan", "inf", and
// "infinity". On failure it returns the NaN value and an error quoting s.
func ParseDecimal128(s string) (Decimal128, error) {
	if s == "" {
		return dErr(s)
	}

	matches := normalNumber.FindStringSubmatch(s)
	if len(matches) == 0 {
		// Not numeric: only the special values remain as candidates.
		unsigned := s
		negative := s[0] == '-'
		if negative || s[0] == '+' {
			unsigned = s[1:]
		}

		switch {
		case strings.EqualFold(unsigned, "nan"):
			return dNaN, nil
		case strings.EqualFold(unsigned, "inf"), strings.EqualFold(unsigned, "infinity"):
			if negative {
				return dNegInf, nil
			}
			return dPosInf, nil
		}
		return dErr(s)
	}

	intPart, decPart, expPart := matches[1], matches[2], matches[3]

	exp := 0
	if expPart != "" {
		parsed, err := strconv.Atoi(expPart)
		if err != nil {
			return dErr(s)
		}
		exp = parsed
	}
	// Every fractional digit moves the decimal point one place to the left.
	exp -= len(decPart)

	digits := intPart + decPart
	if len(strings.Trim(digits, "-0")) > 35 {
		return dErr(s)
	}

	// Parse the significand (i.e. the non-exponent part) as a big.Int.
	bi, ok := new(big.Int).SetString(digits, 10)
	if !ok {
		return dErr(s)
	}

	d, ok := ParseDecimal128FromBigInt(bi, exp)
	if !ok {
		return dErr(s)
	}

	// A big.Int carries no sign for zero, so restore it for inputs like "-0".
	if bi.Sign() == 0 && s[0] == '-' {
		d.h |= 1 << 63
	}

	return d, nil
}
// Shared big.Int values used by ParseDecimal128FromBigInt.
var (
// ten is the divisor and multiplier used when rescaling a significand.
ten = big.NewInt(10)
// zero is the value remainders and significands are compared against.
zero = new(big.Int)
// maxS is the largest significand magnitude accepted (34 decimal nines).
maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
)
// ParseDecimal128FromBigInt builds a Decimal128 from the significand bi and
// the decimal exponent exp. The second result is false when the value cannot
// be represented exactly. bi itself is never modified.
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
	// Work on a private copy so the caller's value stays untouched.
	sig := new(big.Int).Set(bi)
	quo := new(big.Int)
	rem := new(big.Int)

	// If the significand is zero, the logical value will always be zero, independent of the
	// exponent. However, the loops for handling out-of-range exponent values below may be extremely
	// slow for zero values because the significand never changes. Limit the exponent value to the
	// supported range here to prevent entering the loops below.
	if sig.Cmp(zero) == 0 {
		switch {
		case exp > MaxDecimal128Exp:
			exp = MaxDecimal128Exp
		case exp < MinDecimal128Exp:
			exp = MinDecimal128Exp
		}
	}

	// Significand too large: drop trailing zero digits, raising the exponent.
	for bigIntCmpAbs(sig, maxS) == 1 {
		quo.QuoRem(sig, ten, rem)
		sig = quo
		if rem.Cmp(zero) != 0 {
			return Decimal128{}, false
		}
		exp++
		if exp > MaxDecimal128Exp {
			return Decimal128{}, false
		}
	}

	// Subnormal: exponent too small, so divide out trailing zero digits.
	for exp < MinDecimal128Exp {
		quo.QuoRem(sig, ten, rem)
		sig = quo
		if rem.Cmp(zero) != 0 {
			return Decimal128{}, false
		}
		exp++
	}

	// Clamped: exponent too large, so pad the significand with zero digits.
	for exp > MaxDecimal128Exp {
		sig.Mul(sig, ten)
		if bigIntCmpAbs(sig, maxS) == 1 {
			return Decimal128{}, false
		}
		exp--
	}

	// The last eight big-endian magnitude bytes form the low word; anything
	// before them forms the high word.
	raw := sig.Bytes()
	split := len(raw) - 8
	if split < 0 {
		split = 0
	}

	var h, l uint64
	for _, by := range raw[:split] {
		h = h<<8 | uint64(by)
	}
	for _, by := range raw[split:] {
		l = l<<8 | uint64(by)
	}

	// Store the biased exponent in the 14 bits starting at bit 49.
	h |= (uint64(exp-MinDecimal128Exp) & (1<<14 - 1)) << 49
	if sig.Sign() < 0 {
		h |= 1 << 63
	}

	return Decimal128{h: h, l: l}, true
}
// bigIntCmpAbs compares the absolute values of x and y. The result is -1, 0,
// or +1, as for big.Int.Cmp(|x|, |y|).
func bigIntCmpAbs(x, y *big.Int) int {
	return bigIntAbsValue(x).Cmp(bigIntAbsValue(y))
}
// bigIntAbsValue returns the absolute value of b. A non-negative b is returned
// as-is (the same pointer, no copy); a negative b yields a newly allocated
// big.Int and b itself is left unchanged.
func bigIntAbsValue(b *big.Int) *big.Int {
	if b.Sign() < 0 {
		return new(big.Int).Abs(b)
	}
	return b // already non-negative
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
)
// ErrDecodeToNil is the error returned when trying to decode to a nil value.
// Decoder.Decode returns it for a nil pointer or a nil map destination.
var ErrDecodeToNil = errors.New("cannot Decode to nil value")
// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
// New hands out a zero-value Decoder, which has no ValueReader and no Registry.
New: func() interface{} {
return new(Decoder)
},
}
// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
// dc holds the registry and the decoding options set through the Decoder's methods.
dc DecodeContext
// vr is the source the next call to Decode reads from.
vr ValueReader
}
// NewDecoder returns a new Decoder that reads from vr and looks up value
// decoders in the package's default registry.
func NewDecoder(vr ValueReader) *Decoder {
	dec := new(Decoder)
	dec.vr = vr
	dec.dc.Registry = defaultRegistry
	return dec
}
// Decode reads the next BSON document from the stream and decodes it into the
// value pointed to by val.
//
// If val implements Unmarshaler, the document's bytes are passed to its
// UnmarshalBSON method. Otherwise val must be a non-nil pointer or a non-nil
// map: ErrDecodeToNil is returned for a nil pointer or map, and any other kind
// of value is rejected with an error.
//
// See [Unmarshal] for details about BSON unmarshaling behavior.
func (d *Decoder) Decode(val interface{}) error {
	if um, ok := val.(Unmarshaler); ok {
		// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
		raw, err := copyDocumentToBytes(d.vr)
		if err != nil {
			return err
		}
		return um.UnmarshalBSON(raw)
	}

	target := reflect.ValueOf(val)
	kind := target.Kind()
	if kind != reflect.Ptr && kind != reflect.Map {
		return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", target)
	}
	if target.IsNil() {
		return ErrDecodeToNil
	}
	if kind == reflect.Ptr {
		// Decode into the pointed-to value; maps are decoded into directly.
		target = target.Elem()
	}

	dec, err := d.dc.LookupDecoder(target.Type())
	if err != nil {
		return err
	}

	return dec.DecodeValue(d.dc, d.vr, target)
}
// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading. Only the ValueReader is
// replaced; the registry and any options already set are kept.
func (d *Decoder) Reset(vr ValueReader) {
d.vr = vr
}
// SetRegistry replaces the current registry of the decoder with r. Other
// decoding options on the Decoder are left as they are.
func (d *Decoder) SetRegistry(r *Registry) {
d.dc.Registry = r
}
// DefaultDocumentM causes the Decoder to always unmarshal documents into the bson.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
// Record M as the decode context's default document type.
d.dc.defaultDocumentType = reflect.TypeOf(M{})
}
// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
// field. The truncation logic does not apply to BSON "decimal128" values.
func (d *Decoder) AllowTruncatingDoubles() {
// Sets the decode context's truncate flag; there is no method here to clear it.
d.dc.truncate = true
}
// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
// "Old" BSON binary subtype as a Go byte slice instead of a bson.Binary.
func (d *Decoder) BinaryAsSlice() {
// Sets the decode context's binaryAsSlice flag.
d.dc.binaryAsSlice = true
}
// ObjectIDAsHexString causes the Decoder to decode object IDs to their hex representation.
func (d *Decoder) ObjectIDAsHexString() {
// Sets the decode context's objectIDAsHexString flag.
d.dc.objectIDAsHexString = true
}
// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (d *Decoder) UseJSONStructTags() {
// Sets the decode context's useJSONStructTags flag.
d.dc.useJSONStructTags = true
}
// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
// of the UTC timezone.
func (d *Decoder) UseLocalTimeZone() {
// Sets the decode context's useLocalTimeZone flag.
d.dc.useLocalTimeZone = true
}
// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
// passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroMaps() {
// Sets the decode context's zeroMaps flag.
d.dc.zeroMaps = true
}
// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
// value passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroStructs() {
// Sets the decode context's zeroStructs flag.
d.dc.zeroStructs = true
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/url"
"reflect"
"strconv"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// errCannotTruncate is returned when a BSON double would lose information on
// conversion to the destination type and truncation has not been enabled.
var errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled")
// decodeBinaryError reports a BSON binary value whose subtype cannot be
// decoded into the requested Go type.
type decodeBinaryError struct {
// subtype is the binary subtype that was encountered.
subtype byte
// typeName names the destination type in the error message.
typeName string
}
// Error implements the error interface. The message names the destination
// type and the rejected subtype.
func (d decodeBinaryError) Error() string {
return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype)
}
// registerDefaultDecoders registers the package's default value decoders with
// the provided Registry.
//
// It installs, in this order: decoders for specific Go types, decoders for Go
// kinds, the mapping from each BSON type to the Go type used when no concrete
// destination type is known, and the decoders for the ValueUnmarshaler and
// Unmarshaler interfaces. The empty interface type (tEmpty) gets its decoder
// from emptyInterfaceCodec.
func registerDefaultDecoders(reg *Registry) {
// Decoders shared across several registrations below.
intDecoder := decodeAdapter{intDecodeValue, intDecodeType}
floatDecoder := decodeAdapter{floatDecodeValue, floatDecodeType}
uintCodec := &uintCodec{}
// Decoders keyed by exact Go type.
reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue))
reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType})
reg.RegisterTypeDecoder(tVector, decodeAdapter{vectorDecodeValue, vectorDecodeType})
reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType})
reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType})
reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType})
reg.RegisterTypeDecoder(tRegex, decodeAdapter{regexDecodeValue, regexDecodeType})
reg.RegisterTypeDecoder(tDBPointer, decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType})
reg.RegisterTypeDecoder(tTimestamp, decodeAdapter{timestampDecodeValue, timestampDecodeType})
reg.RegisterTypeDecoder(tMinKey, decodeAdapter{minKeyDecodeValue, minKeyDecodeType})
reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType})
reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType})
reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType})
reg.RegisterTypeDecoder(tByteSlice, &byteSliceCodec{})
reg.RegisterTypeDecoder(tTime, &timeCodec{})
reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{})
reg.RegisterTypeDecoder(tCoreArray, &arrayCodec{})
reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType})
reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType})
reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType})
reg.RegisterTypeDecoder(tURL, decodeAdapter{urlDecodeValue, urlDecodeType})
reg.RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(coreDocumentDecodeValue))
reg.RegisterTypeDecoder(tCodeWithScope, decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType})
// Decoders keyed by reflect.Kind.
reg.RegisterKindDecoder(reflect.Bool, decodeAdapter{booleanDecodeValue, booleanDecodeType})
reg.RegisterKindDecoder(reflect.Int, intDecoder)
reg.RegisterKindDecoder(reflect.Int8, intDecoder)
reg.RegisterKindDecoder(reflect.Int16, intDecoder)
reg.RegisterKindDecoder(reflect.Int32, intDecoder)
reg.RegisterKindDecoder(reflect.Int64, intDecoder)
reg.RegisterKindDecoder(reflect.Uint, uintCodec)
reg.RegisterKindDecoder(reflect.Uint8, uintCodec)
reg.RegisterKindDecoder(reflect.Uint16, uintCodec)
reg.RegisterKindDecoder(reflect.Uint32, uintCodec)
reg.RegisterKindDecoder(reflect.Uint64, uintCodec)
reg.RegisterKindDecoder(reflect.Float32, floatDecoder)
reg.RegisterKindDecoder(reflect.Float64, floatDecoder)
reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue))
reg.RegisterKindDecoder(reflect.Map, &mapCodec{})
reg.RegisterKindDecoder(reflect.Slice, &sliceCodec{})
reg.RegisterKindDecoder(reflect.String, &stringCodec{})
reg.RegisterKindDecoder(reflect.Struct, newStructCodec(nil))
reg.RegisterKindDecoder(reflect.Ptr, &pointerCodec{})
// BSON type -> Go type mapping.
reg.RegisterTypeMapEntry(TypeDouble, tFloat64)
reg.RegisterTypeMapEntry(TypeString, tString)
reg.RegisterTypeMapEntry(TypeArray, tA)
reg.RegisterTypeMapEntry(TypeBinary, tBinary)
reg.RegisterTypeMapEntry(TypeUndefined, tUndefined)
reg.RegisterTypeMapEntry(TypeObjectID, tOID)
reg.RegisterTypeMapEntry(TypeBoolean, tBool)
reg.RegisterTypeMapEntry(TypeDateTime, tDateTime)
reg.RegisterTypeMapEntry(TypeRegex, tRegex)
reg.RegisterTypeMapEntry(TypeDBPointer, tDBPointer)
reg.RegisterTypeMapEntry(TypeJavaScript, tJavaScript)
reg.RegisterTypeMapEntry(TypeSymbol, tSymbol)
reg.RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope)
reg.RegisterTypeMapEntry(TypeInt32, tInt32)
reg.RegisterTypeMapEntry(TypeInt64, tInt64)
reg.RegisterTypeMapEntry(TypeTimestamp, tTimestamp)
reg.RegisterTypeMapEntry(TypeDecimal128, tDecimal)
reg.RegisterTypeMapEntry(TypeMinKey, tMinKey)
reg.RegisterTypeMapEntry(TypeMaxKey, tMaxKey)
// Both the zero Type and embedded documents map to D.
reg.RegisterTypeMapEntry(Type(0), tD)
reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tD)
// Decoders for types implementing the unmarshaling interfaces.
reg.RegisterInterfaceDecoder(tValueUnmarshaler, ValueDecoderFunc(valueUnmarshalerDecodeValue))
reg.RegisterInterfaceDecoder(tUnmarshaler, ValueDecoderFunc(unmarshalerDecodeValue))
}
// dDecodeValue is the ValueDecoderFunc for D instances.
//
// A BSON null sets val to its zero value. A document (or the zero Type) is
// decoded element by element using the decoder registered for the empty
// interface type; when val already holds a non-nil D, that slice is truncated
// to length zero and appended to. Any other BSON type is an error.
func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.IsValid() || !val.CanSet() || val.Type() != tD {
		return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
	}

	srcType := vr.Type()
	if srcType == TypeNull {
		val.Set(reflect.Zero(val.Type()))
		return vr.ReadNull()
	}
	if srcType != Type(0) && srcType != TypeEmbeddedDocument {
		return fmt.Errorf("cannot decode %v into a D", srcType)
	}

	docReader, err := vr.ReadDocument()
	if err != nil {
		return err
	}

	elemDecoder, err := dc.LookupDecoder(tEmpty)
	if err != nil {
		return err
	}

	// Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance.
	var elems D
	if val.IsNil() {
		elems = make(D, 0)
	} else {
		val.SetLen(0)
		elems = val.Interface().(D)
	}

	for {
		key, elemReader, err := docReader.ReadElement()
		if errors.Is(err, ErrEOD) {
			break
		}
		if err != nil {
			return err
		}

		var decoded interface{}
		if err := elemDecoder.DecodeValue(dc, elemReader, reflect.ValueOf(&decoded).Elem()); err != nil {
			return err
		}

		elems = append(elems, E{Key: key, Value: decoded})
	}

	val.Set(reflect.ValueOf(elems))
	return nil
}
// booleanDecodeType reads a value from vr and returns it as a bool wrapped in
// a reflect.Value. t must be of kind reflect.Bool.
//
// Int32, int64, and double values decode to true when non-zero. Null and
// undefined decode to false. Any other BSON type is an error.
func booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t.Kind() != reflect.Bool {
		return emptyValue, ValueDecoderError{
			Name:     "BooleanDecodeValue",
			Kinds:    []reflect.Kind{reflect.Bool},
			Received: reflect.Zero(t),
		}
	}

	var (
		result bool
		err    error
	)
	switch vrType := vr.Type(); vrType {
	case TypeInt32:
		var i32 int32
		i32, err = vr.ReadInt32()
		result = i32 != 0
	case TypeInt64:
		var i64 int64
		i64, err = vr.ReadInt64()
		result = i64 != 0
	case TypeDouble:
		var f64 float64
		f64, err = vr.ReadDouble()
		result = f64 != 0
	case TypeBoolean:
		result, err = vr.ReadBoolean()
	case TypeNull:
		err = vr.ReadNull()
	case TypeUndefined:
		err = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a boolean", vrType)
	}
	// A failed read is reported here; whatever result holds is discarded.
	if err != nil {
		return emptyValue, err
	}

	return reflect.ValueOf(result), nil
}
// booleanDecodeValue is the ValueDecoderFunc for bool types. val must be a
// valid, settable value of kind reflect.Bool; the decoded result is stored in it.
func booleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	usable := val.IsValid() && val.CanSet() && val.Kind() == reflect.Bool
	if !usable {
		return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
	}

	decoded, err := booleanDecodeType(dctx, vr, val.Type())
	if err != nil {
		return err
	}

	val.SetBool(decoded.Bool())
	return nil
}
// intDecodeType reads a value from vr and returns it converted to the signed
// integer type t (int, int8, int16, int32, or int64), wrapped in a
// reflect.Value.
//
// Int32 and int64 values are used directly. A double is accepted when it has
// no fractional part, or when dc.truncate is set, and when it lies within the
// int64 range. A boolean decodes to 1 or 0. Null and undefined decode to 0.
// Any other BSON type is an error.
//
// An error is returned when the value does not fit in t. If t is not a signed
// integer kind, a ValueDecoderError is returned.
func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	var i64 int64
	var err error
	switch vrType := vr.Type(); vrType {
	case TypeInt32:
		i32, err := vr.ReadInt32()
		if err != nil {
			return emptyValue, err
		}
		i64 = int64(i32)
	case TypeInt64:
		i64, err = vr.ReadInt64()
		if err != nil {
			return emptyValue, err
		}
	case TypeDouble:
		f64, err := vr.ReadDouble()
		if err != nil {
			return emptyValue, err
		}
		if !dc.truncate && math.Floor(f64) != f64 {
			return emptyValue, errCannotTruncate
		}
		// float64(math.MaxInt64) rounds up to 2^63, which is already outside
		// the int64 range, so the upper bound is exclusive. The lower bound
		// -2^63 is exactly representable and inclusive. Converting an
		// out-of-range float to int64 gives an implementation-dependent
		// result, so both ends must be rejected.
		if f64 >= float64(math.MaxInt64) || f64 < float64(math.MinInt64) {
			return emptyValue, fmt.Errorf("%g overflows int64", f64)
		}
		i64 = int64(f64)
	case TypeBoolean:
		b, err := vr.ReadBoolean()
		if err != nil {
			return emptyValue, err
		}
		if b {
			i64 = 1
		}
	case TypeNull:
		if err = vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err = vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
	}

	switch t.Kind() {
	case reflect.Int8:
		if i64 < math.MinInt8 || i64 > math.MaxInt8 {
			return emptyValue, fmt.Errorf("%d overflows int8", i64)
		}
		return reflect.ValueOf(int8(i64)), nil
	case reflect.Int16:
		if i64 < math.MinInt16 || i64 > math.MaxInt16 {
			return emptyValue, fmt.Errorf("%d overflows int16", i64)
		}
		return reflect.ValueOf(int16(i64)), nil
	case reflect.Int32:
		if i64 < math.MinInt32 || i64 > math.MaxInt32 {
			return emptyValue, fmt.Errorf("%d overflows int32", i64)
		}
		return reflect.ValueOf(int32(i64)), nil
	case reflect.Int64:
		return reflect.ValueOf(i64), nil
	case reflect.Int:
		// int may be 32 bits wide. A round trip through int detects values
		// that do not fit, in either direction, on any platform.
		if int64(int(i64)) != i64 {
			return emptyValue, fmt.Errorf("%d overflows int", i64)
		}
		return reflect.ValueOf(int(i64)), nil
	default:
		return emptyValue, ValueDecoderError{
			Name:     "IntDecodeValue",
			Kinds:    []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
			Received: reflect.Zero(t),
		}
	}
}
// intDecodeValue is the ValueDecoderFunc for int types. val must be settable;
// the value decoded by intDecodeType for val's type is stored in it.
func intDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if val.CanSet() {
		decoded, err := intDecodeType(dc, vr, val.Type())
		if err == nil {
			val.SetInt(decoded.Int())
		}
		return err
	}

	return ValueDecoderError{
		Name:     "IntDecodeValue",
		Kinds:    []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
		Received: val,
	}
}
// floatDecodeType reads a value from vr and returns it converted to the
// floating-point type t (float32 or float64), wrapped in a reflect.Value.
//
// Int32, int64, and double values are converted numerically. A boolean decodes
// to 1 or 0. Null and undefined decode to 0. Any other BSON type is an error.
// For float32, errCannotTruncate is returned when the value does not survive a
// round trip through float32 and dc.truncate is not set. If t is not a float
// kind, a ValueDecoderError is returned.
func floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	var (
		result float64
		err    error
	)
	switch vrType := vr.Type(); vrType {
	case TypeInt32:
		var i32 int32
		i32, err = vr.ReadInt32()
		result = float64(i32)
	case TypeInt64:
		var i64 int64
		i64, err = vr.ReadInt64()
		result = float64(i64)
	case TypeDouble:
		result, err = vr.ReadDouble()
	case TypeBoolean:
		var b bool
		b, err = vr.ReadBoolean()
		if b {
			result = 1
		}
	case TypeNull:
		err = vr.ReadNull()
	case TypeUndefined:
		err = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType)
	}
	// A failed read is reported here; whatever result holds is discarded.
	if err != nil {
		return emptyValue, err
	}

	switch t.Kind() {
	case reflect.Float32:
		narrowed := float32(result)
		if !dc.truncate && float64(narrowed) != result {
			return emptyValue, errCannotTruncate
		}
		return reflect.ValueOf(narrowed), nil
	case reflect.Float64:
		return reflect.ValueOf(result), nil
	default:
		return emptyValue, ValueDecoderError{
			Name:     "FloatDecodeValue",
			Kinds:    []reflect.Kind{reflect.Float32, reflect.Float64},
			Received: reflect.Zero(t),
		}
	}
}
// floatDecodeValue is the ValueDecoderFunc for float types. val must be
// settable; the value decoded by floatDecodeType for val's type is stored in it.
func floatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error {
	if val.CanSet() {
		decoded, err := floatDecodeType(ec, vr, val.Type())
		if err == nil {
			val.SetFloat(decoded.Float())
		}
		return err
	}

	return ValueDecoderError{
		Name:     "FloatDecodeValue",
		Kinds:    []reflect.Kind{reflect.Float32, reflect.Float64},
		Received: val,
	}
}
// javaScriptDecodeType decodes the next value from vr into a JavaScript value.
// t must be tJavaScript. A BSON JavaScript value yields its code string; null
// and undefined are consumed and yield an empty JavaScript. Every other BSON
// type is an error.
func javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tJavaScript {
		return emptyValue, ValueDecoderError{
			Name:     "JavaScriptDecodeValue",
			Types:    []reflect.Type{tJavaScript},
			Received: reflect.Zero(t),
		}
	}

	var (
		code    string
		readErr error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeJavaScript:
		code, readErr = vr.ReadJavascript()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a JavaScript", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(JavaScript(code)), nil
}
// javaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. val
// must be a settable JavaScript value; the string decoded by
// javaScriptDecodeType is stored into it with SetString.
func javaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tJavaScript) {
		return ValueDecoderError{
			Name:     "JavaScriptDecodeValue",
			Types:    []reflect.Type{tJavaScript},
			Received: val,
		}
	}

	decoded, decodeErr := javaScriptDecodeType(dctx, vr, tJavaScript)
	if decodeErr != nil {
		return decodeErr
	}

	val.SetString(decoded.String())
	return nil
}
// symbolDecodeType decodes the next value from vr into a Symbol. t must be
// tSymbol. BSON string and symbol values are used as-is, and binary values
// with subtype TypeBinaryGeneric or TypeBinaryBinaryOld are converted from
// their bytes; any other binary subtype yields a decodeBinaryError. Null and
// undefined are consumed and yield an empty Symbol. Every other BSON type is
// an error.
func symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tSymbol {
		return emptyValue, ValueDecoderError{
			Name:     "SymbolDecodeValue",
			Types:    []reflect.Type{tSymbol},
			Received: reflect.Zero(t),
		}
	}

	var (
		text    string
		readErr error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeSymbol:
		text, readErr = vr.ReadSymbol()
	case TypeString:
		text, readErr = vr.ReadString()
	case TypeBinary:
		payload, subtype, err := vr.ReadBinary()
		if err != nil {
			return emptyValue, err
		}
		switch subtype {
		case TypeBinaryGeneric, TypeBinaryBinaryOld:
			text = string(payload)
		default:
			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "Symbol"}
		}
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a Symbol", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(Symbol(text)), nil
}
// symbolDecodeValue is the ValueDecoderFunc for the Symbol type. val must be a
// settable Symbol value; the string decoded by symbolDecodeType is stored into
// it with SetString.
func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tSymbol) {
		return ValueDecoderError{
			Name:     "SymbolDecodeValue",
			Types:    []reflect.Type{tSymbol},
			Received: val,
		}
	}

	decoded, decodeErr := symbolDecodeType(dctx, vr, tSymbol)
	if decodeErr != nil {
		return decodeErr
	}

	val.SetString(decoded.String())
	return nil
}
// binaryDecode reads the next value from vr as a Binary. A BSON binary value
// supplies the subtype and data; null and undefined are consumed and yield a
// zero Binary. Every other BSON type is an error, and on any error the
// returned Binary is the zero value.
func binaryDecode(vr ValueReader) (Binary, error) {
	var (
		payload []byte
		kind    byte
		readErr error
	)

	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeBinary:
		payload, kind, readErr = vr.ReadBinary()
	default:
		return Binary{}, fmt.Errorf("cannot decode %v into a Binary", bsonType)
	}
	if readErr != nil {
		return Binary{}, readErr
	}

	return Binary{Subtype: kind, Data: payload}, nil
}
// binaryDecodeType decodes the next value from vr into a Binary wrapped in a
// reflect.Value. t must be tBinary; the reading itself is delegated to
// binaryDecode.
func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tBinary {
		return emptyValue, ValueDecoderError{
			Name:     "BinaryDecodeValue",
			Types:    []reflect.Type{tBinary},
			Received: reflect.Zero(t),
		}
	}

	bin, decodeErr := binaryDecode(vr)
	if decodeErr != nil {
		return emptyValue, decodeErr
	}
	return reflect.ValueOf(bin), nil
}
// binaryDecodeValue is the ValueDecoderFunc for Binary. val must be a settable
// Binary value; the result of binaryDecodeType is stored into it.
func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tBinary) {
		return ValueDecoderError{
			Name:     "BinaryDecodeValue",
			Types:    []reflect.Type{tBinary},
			Received: val,
		}
	}

	decoded, decodeErr := binaryDecodeType(dc, vr, tBinary)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// vectorDecodeType decodes the next value from vr into a Vector wrapped in a
// reflect.Value. t must be tVector. The value is first read with binaryDecode
// and then converted with NewVectorFromBinary; an error from either step is
// returned unchanged.
func vectorDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tVector {
		return emptyValue, ValueDecoderError{
			Name:     "VectorDecodeValue",
			Types:    []reflect.Type{tVector},
			Received: reflect.Zero(t),
		}
	}

	bin, readErr := binaryDecode(vr)
	if readErr != nil {
		return emptyValue, readErr
	}

	vec, convErr := NewVectorFromBinary(bin)
	if convErr != nil {
		return emptyValue, convErr
	}
	return reflect.ValueOf(vec), nil
}
// vectorDecodeValue is the ValueDecoderFunc for Vector. val must be a settable
// Vector value; the result of vectorDecodeType is stored into it. Any other
// val, including the zero reflect.Value, is rejected with a ValueDecoderError.
func vectorDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	// CanSet must be evaluated before Type: reflect.Value.Type panics on the
	// zero Value, while CanSet simply reports false for it. Short-circuiting
	// keeps this decoder consistent with the other type decoders in this file.
	if !val.CanSet() || val.Type() != tVector {
		return ValueDecoderError{
			Name:     "VectorDecodeValue",
			Types:    []reflect.Type{tVector},
			Received: val,
		}
	}

	elem, err := vectorDecodeType(dctx, vr, tVector)
	if err != nil {
		return err
	}

	val.Set(elem)
	return nil
}
// undefinedDecodeType consumes the next value from vr, which must be BSON
// undefined or null, and returns an Undefined wrapped in a reflect.Value. t
// must be tUndefined. Every other BSON type is an error.
func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tUndefined {
		return emptyValue, ValueDecoderError{
			Name:     "UndefinedDecodeValue",
			Types:    []reflect.Type{tUndefined},
			Received: reflect.Zero(t),
		}
	}

	var readErr error
	switch vr.Type() {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into an Undefined", vr.Type())
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(Undefined{}), nil
}
// undefinedDecodeValue is the ValueDecoderFunc for Undefined. val must be a
// settable Undefined value; the result of undefinedDecodeType is stored into
// it.
func undefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tUndefined) {
		return ValueDecoderError{
			Name:     "UndefinedDecodeValue",
			Types:    []reflect.Type{tUndefined},
			Received: val,
		}
	}

	decoded, decodeErr := undefinedDecodeType(dc, vr, tUndefined)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// objectIDDecodeType decodes the next value from vr into an ObjectID wrapped
// in a reflect.Value. t must be tOID.
//
// Besides BSON ObjectID values, strings are accepted in two forms: a string
// that ObjectIDFromHex can parse, or a string of exactly 12 bytes whose bytes
// are copied verbatim. Null and undefined are consumed and yield the zero
// ObjectID. Every other BSON type is an error.
func objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tOID {
		return emptyValue, ValueDecoderError{
			Name:     "ObjectIDDecodeValue",
			Types:    []reflect.Type{tOID},
			Received: reflect.Zero(t),
		}
	}

	var id ObjectID

	bsonType := vr.Type()
	switch bsonType {
	case TypeObjectID:
		read, err := vr.ReadObjectID()
		if err != nil {
			return emptyValue, err
		}
		id = read
	case TypeString:
		raw, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		parsed, hexErr := ObjectIDFromHex(raw)
		switch {
		case hexErr == nil:
			id = parsed
		case len(raw) == 12:
			// Not valid hex: fall back to treating the string as raw bytes.
			copy(id[:], raw)
		default:
			return emptyValue, fmt.Errorf("an ObjectID string must be exactly 12 bytes long (got %v)", len(raw))
		}
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into an ObjectID", bsonType)
	}

	return reflect.ValueOf(id), nil
}
// objectIDDecodeValue is the ValueDecoderFunc for ObjectID. val must be a
// settable ObjectID value; the result of objectIDDecodeType is stored into it.
func objectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tOID) {
		return ValueDecoderError{
			Name:     "ObjectIDDecodeValue",
			Types:    []reflect.Type{tOID},
			Received: val,
		}
	}

	decoded, decodeErr := objectIDDecodeType(dc, vr, tOID)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// dateTimeDecodeType decodes the next value from vr into a DateTime wrapped in
// a reflect.Value. t must be tDateTime. Null and undefined are consumed and
// yield a zero DateTime. Every BSON type other than datetime, null and
// undefined is an error.
func dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tDateTime {
		return emptyValue, ValueDecoderError{
			Name:     "DateTimeDecodeValue",
			Types:    []reflect.Type{tDateTime},
			Received: reflect.Zero(t),
		}
	}

	var (
		raw     int64
		readErr error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeDateTime:
		raw, readErr = vr.ReadDateTime()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a DateTime", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(DateTime(raw)), nil
}
// dateTimeDecodeValue is the ValueDecoderFunc for DateTime. val must be a
// settable DateTime value; the result of dateTimeDecodeType is stored into it.
func dateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tDateTime) {
		return ValueDecoderError{
			Name:     "DateTimeDecodeValue",
			Types:    []reflect.Type{tDateTime},
			Received: val,
		}
	}

	decoded, decodeErr := dateTimeDecodeType(dc, vr, tDateTime)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// nullDecodeType consumes the next value from vr, which must be BSON null or
// undefined, and returns a Null wrapped in a reflect.Value. t must be tNull.
// Every other BSON type is an error.
func nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tNull {
		return emptyValue, ValueDecoderError{
			Name:     "NullDecodeValue",
			Types:    []reflect.Type{tNull},
			Received: reflect.Zero(t),
		}
	}

	var readErr error
	switch vr.Type() {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a Null", vr.Type())
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(Null{}), nil
}
// nullDecodeValue is the ValueDecoderFunc for Null. val must be a settable
// Null value; the result of nullDecodeType is stored into it.
func nullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tNull) {
		return ValueDecoderError{
			Name:     "NullDecodeValue",
			Types:    []reflect.Type{tNull},
			Received: val,
		}
	}

	decoded, decodeErr := nullDecodeType(dc, vr, tNull)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// regexDecodeType decodes the next value from vr into a Regex wrapped in a
// reflect.Value. t must be tRegex. A BSON regex supplies the pattern and
// options; null and undefined are consumed and yield a zero Regex. Every other
// BSON type is an error.
func regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tRegex {
		return emptyValue, ValueDecoderError{
			Name:     "RegexDecodeValue",
			Types:    []reflect.Type{tRegex},
			Received: reflect.Zero(t),
		}
	}

	var (
		expr, flags string
		readErr     error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeRegex:
		expr, flags, readErr = vr.ReadRegex()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a Regex", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(Regex{Pattern: expr, Options: flags}), nil
}
// regexDecodeValue is the ValueDecoderFunc for Regex. val must be a settable
// Regex value; the result of regexDecodeType is stored into it.
func regexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tRegex) {
		return ValueDecoderError{
			Name:     "RegexDecodeValue",
			Types:    []reflect.Type{tRegex},
			Received: val,
		}
	}

	decoded, decodeErr := regexDecodeType(dc, vr, tRegex)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// dbPointerDecodeType decodes the next value from vr into a DBPointer wrapped
// in a reflect.Value. t must be tDBPointer. A BSON DBPointer supplies the DB
// string and the ObjectID; null and undefined are consumed and yield a zero
// DBPointer. Every other BSON type is an error.
func dbPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tDBPointer {
		return emptyValue, ValueDecoderError{
			Name:     "DBPointerDecodeValue",
			Types:    []reflect.Type{tDBPointer},
			Received: reflect.Zero(t),
		}
	}

	var (
		namespace string
		target    ObjectID
		readErr   error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeDBPointer:
		namespace, target, readErr = vr.ReadDBPointer()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a DBPointer", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(DBPointer{DB: namespace, Pointer: target}), nil
}
// dbPointerDecodeValue is the ValueDecoderFunc for DBPointer. val must be a
// settable DBPointer value; the result of dbPointerDecodeType is stored into
// it.
func dbPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tDBPointer) {
		return ValueDecoderError{
			Name:     "DBPointerDecodeValue",
			Types:    []reflect.Type{tDBPointer},
			Received: val,
		}
	}

	decoded, decodeErr := dbPointerDecodeType(dc, vr, tDBPointer)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// timestampDecodeType decodes the next value from vr into a Timestamp wrapped
// in a reflect.Value. reflectType must be tTimestamp. A BSON timestamp
// supplies the T and I fields; null and undefined are consumed and yield a
// zero Timestamp. Every other BSON type is an error.
func timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) {
	if reflectType != tTimestamp {
		return emptyValue, ValueDecoderError{
			Name:     "TimestampDecodeValue",
			Types:    []reflect.Type{tTimestamp},
			Received: reflect.Zero(reflectType),
		}
	}

	var (
		tsT, tsI uint32
		readErr  error
	)
	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeTimestamp:
		tsT, tsI, readErr = vr.ReadTimestamp()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a Timestamp", bsonType)
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(Timestamp{T: tsT, I: tsI}), nil
}
// timestampDecodeValue is the ValueDecoderFunc for Timestamp. val must be a
// settable Timestamp value; the result of timestampDecodeType is stored into
// it.
func timestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tTimestamp) {
		return ValueDecoderError{
			Name:     "TimestampDecodeValue",
			Types:    []reflect.Type{tTimestamp},
			Received: val,
		}
	}

	decoded, decodeErr := timestampDecodeType(dc, vr, tTimestamp)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// minKeyDecodeType consumes the next value from vr, which must be BSON min
// key, null or undefined, and returns a MinKey wrapped in a reflect.Value. t
// must be tMinKey. Every other BSON type is an error.
func minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tMinKey {
		return emptyValue, ValueDecoderError{
			Name:     "MinKeyDecodeValue",
			Types:    []reflect.Type{tMinKey},
			Received: reflect.Zero(t),
		}
	}

	var readErr error
	switch vr.Type() {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeMinKey:
		readErr = vr.ReadMinKey()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a MinKey", vr.Type())
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(MinKey{}), nil
}
// minKeyDecodeValue is the ValueDecoderFunc for MinKey. val must be a settable
// MinKey value; the result of minKeyDecodeType is stored into it.
func minKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tMinKey) {
		return ValueDecoderError{
			Name:     "MinKeyDecodeValue",
			Types:    []reflect.Type{tMinKey},
			Received: val,
		}
	}

	decoded, decodeErr := minKeyDecodeType(dc, vr, tMinKey)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// maxKeyDecodeType consumes the next value from vr, which must be BSON max
// key, null or undefined, and returns a MaxKey wrapped in a reflect.Value. t
// must be tMaxKey. Every other BSON type is an error.
func maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tMaxKey {
		return emptyValue, ValueDecoderError{
			Name:     "MaxKeyDecodeValue",
			Types:    []reflect.Type{tMaxKey},
			Received: reflect.Zero(t),
		}
	}

	var readErr error
	switch vr.Type() {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeMaxKey:
		readErr = vr.ReadMaxKey()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a MaxKey", vr.Type())
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(MaxKey{}), nil
}
// maxKeyDecodeValue is the ValueDecoderFunc for MaxKey. val must be a settable
// MaxKey value; the result of maxKeyDecodeType is stored into it.
func maxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tMaxKey) {
		return ValueDecoderError{
			Name:     "MaxKeyDecodeValue",
			Types:    []reflect.Type{tMaxKey},
			Received: val,
		}
	}

	decoded, decodeErr := maxKeyDecodeType(dc, vr, tMaxKey)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// decimal128DecodeType decodes the next value from vr into a Decimal128
// wrapped in a reflect.Value. t must be tDecimal. Null and undefined are
// consumed and yield a zero Decimal128. Every BSON type other than
// decimal128, null and undefined is an error.
func decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tDecimal {
		return emptyValue, ValueDecoderError{
			Name:     "Decimal128DecodeValue",
			Types:    []reflect.Type{tDecimal},
			Received: reflect.Zero(t),
		}
	}

	var (
		dec     Decimal128
		readErr error
	)
	switch vr.Type() {
	case TypeNull:
		readErr = vr.ReadNull()
	case TypeUndefined:
		readErr = vr.ReadUndefined()
	case TypeDecimal128:
		dec, readErr = vr.ReadDecimal128()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a Decimal128", vr.Type())
	}
	if readErr != nil {
		return emptyValue, readErr
	}

	return reflect.ValueOf(dec), nil
}
// decimal128DecodeValue is the ValueDecoderFunc for Decimal128. val must be a
// settable Decimal128 value; the result of decimal128DecodeType is stored into
// it.
func decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tDecimal) {
		return ValueDecoderError{
			Name:     "Decimal128DecodeValue",
			Types:    []reflect.Type{tDecimal},
			Received: val,
		}
	}

	decoded, decodeErr := decimal128DecodeType(dctx, vr, tDecimal)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// jsonNumberDecodeType decodes the next value from vr into a json.Number
// wrapped in a reflect.Value. t must be tJSONNumber.
//
// Doubles are formatted with strconv.FormatFloat ('f' format, shortest
// precision) and int32/int64 values with strconv.FormatInt in base 10. Null
// and undefined are consumed and yield an empty json.Number. Every other BSON
// type is an error.
func jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tJSONNumber {
		return emptyValue, ValueDecoderError{
			Name:     "JSONNumberDecodeValue",
			Types:    []reflect.Type{tJSONNumber},
			Received: reflect.Zero(t),
		}
	}

	var num json.Number

	bsonType := vr.Type()
	switch bsonType {
	case TypeInt32:
		v, err := vr.ReadInt32()
		if err != nil {
			return emptyValue, err
		}
		num = json.Number(strconv.FormatInt(int64(v), 10))
	case TypeInt64:
		v, err := vr.ReadInt64()
		if err != nil {
			return emptyValue, err
		}
		num = json.Number(strconv.FormatInt(v, 10))
	case TypeDouble:
		v, err := vr.ReadDouble()
		if err != nil {
			return emptyValue, err
		}
		num = json.Number(strconv.FormatFloat(v, 'f', -1, 64))
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a json.Number", bsonType)
	}

	return reflect.ValueOf(num), nil
}
// jsonNumberDecodeValue is the ValueDecoderFunc for json.Number. val must be a
// settable json.Number value; the result of jsonNumberDecodeType is stored
// into it.
func jsonNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tJSONNumber) {
		return ValueDecoderError{
			Name:     "JSONNumberDecodeValue",
			Types:    []reflect.Type{tJSONNumber},
			Received: val,
		}
	}

	decoded, decodeErr := jsonNumberDecodeType(dc, vr, tJSONNumber)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// urlDecodeType decodes the next value from vr into a url.URL wrapped in a
// reflect.Value. t must be tURL. A BSON string is parsed with url.Parse; null
// and undefined are consumed and yield an empty url.URL. Every other BSON type
// is an error. The returned value is the URL itself, not a pointer to it.
func urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tURL {
		return emptyValue, ValueDecoderError{
			Name:     "URLDecodeValue",
			Types:    []reflect.Type{tURL},
			Received: reflect.Zero(t),
		}
	}

	parsed := &url.URL{}

	bsonType := vr.Type()
	switch bsonType {
	case TypeString:
		raw, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		if parsed, err = url.Parse(raw); err != nil {
			return emptyValue, err
		}
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a *url.URL", bsonType)
	}

	return reflect.ValueOf(parsed).Elem(), nil
}
// urlDecodeValue is the ValueDecoderFunc for url.URL. val must be a settable
// url.URL value; the result of urlDecodeType is stored into it.
func urlDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tURL) {
		return ValueDecoderError{
			Name:     "URLDecodeValue",
			Types:    []reflect.Type{tURL},
			Received: val,
		}
	}

	decoded, decodeErr := urlDecodeType(dc, vr, tURL)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// arrayDecodeValue is the ValueDecoderFunc for array types.
//
// val must be a valid Go array. The accepted BSON inputs are:
//   - an array, decoded element by element via decodeDefault (or decodeD when
//     the element type is tE);
//   - a document (or a reader reporting Type(0)), only when the element type
//     is tE;
//   - binary with subtype TypeBinaryGeneric or TypeBinaryBinaryOld, only when
//     the element type is tByte;
//   - null or undefined, which reset val to its zero value.
//
// More decoded elements than val can hold is an error. Elements of val beyond
// the decoded ones are left as they were.
func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.IsValid() && val.Kind() == reflect.Array) {
		return ValueDecoderError{
			Name:     "ArrayDecodeValue",
			Kinds:    []reflect.Kind{reflect.Array},
			Received: val,
		}
	}

	arrType := val.Type()
	elemType := arrType.Elem()

	bsonType := vr.Type()
	switch bsonType {
	case TypeNull:
		val.Set(reflect.Zero(arrType))
		return vr.ReadNull()
	case TypeUndefined:
		val.Set(reflect.Zero(arrType))
		return vr.ReadUndefined()
	case TypeBinary:
		if elemType != tByte {
			return fmt.Errorf("ArrayDecodeValue can only be used to decode binary into a byte array, got %v", bsonType)
		}
		payload, subtype, err := vr.ReadBinary()
		if err != nil {
			return err
		}
		if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
			return fmt.Errorf("ArrayDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", TypeBinary, subtype)
		}
		if len(payload) > val.Len() {
			return fmt.Errorf("more elements returned in array than can fit inside %s", arrType)
		}
		for i := range payload {
			val.Index(i).Set(reflect.ValueOf(payload[i]))
		}
		return nil
	case Type(0), TypeEmbeddedDocument:
		if elemType != tE {
			return fmt.Errorf("cannot decode document into %s", arrType)
		}
	case TypeArray:
		// Handled by the element decoding below.
	default:
		return fmt.Errorf("cannot decode %v into an array", bsonType)
	}

	// Arrays of E are filled from document elements; everything else goes
	// through the generic array element decoder.
	decodeElems := decodeDefault
	if elemType == tE {
		decodeElems = decodeD
	}

	decoded, err := decodeElems(dc, vr, val)
	if err != nil {
		return err
	}
	if len(decoded) > val.Len() {
		return fmt.Errorf("more elements returned in array than can fit inside %s, got %v elements", arrType, len(decoded))
	}

	for i := range decoded {
		val.Index(i).Set(decoded[i])
	}
	return nil
}
// valueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations.
//
// val must be valid, and either its type or a pointer to its type must
// implement the interface described by tValueUnmarshaler; otherwise a
// ValueDecoderError is returned. The value is copied out of vr with
// copyValueToBytes and passed, together with its type converted to a byte, to
// UnmarshalBSONValue, whose error is returned.
func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
// If BSON value is null and the go value is a pointer, then don't call
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
// non-nil), encountering null in BSON will result in the pointer being
// directly set to nil here. Since the pointer is being replaced with nil,
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
// to be called.
// NOTE(review): val.Set is called here without a CanSet check, unlike the
// nil-pointer allocation below; presumably callers always pass a settable
// pointer here. Confirm.
if vr.Type() == TypeNull && val.Kind() == reflect.Ptr {
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
}
// A nil pointer is replaced by a freshly allocated element so that there is
// a receiver to unmarshal into.
if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
val.Set(reflect.New(val.Type().Elem()))
}
if !val.Type().Implements(tValueUnmarshaler) {
if !val.CanAddr() {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
}
// The value is consumed from vr only after the receiver has been
// established.
t, src, err := copyValueToBytes(vr)
if err != nil {
return err
}
m, ok := val.Interface().(ValueUnmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
return m.UnmarshalBSONValue(byte(t), src)
}
// unmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations.
//
// val must be valid, and either its type or a pointer to its type must
// implement the interface described by tUnmarshaler; otherwise a
// ValueDecoderError is returned. The value is copied out of vr with
// copyValueToBytes and the resulting bytes are passed to UnmarshalBSON, whose
// error is returned. A pointer val is reset to nil instead when the copied
// bytes are empty.
func unmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) {
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
// A nil pointer is replaced by a freshly allocated element so that there is
// a receiver to unmarshal into.
if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
val.Set(reflect.New(val.Type().Elem()))
}
// The BSON type returned by copyValueToBytes is not needed here; only the
// raw bytes are handed to UnmarshalBSON.
_, src, err := copyValueToBytes(vr)
if err != nil {
return err
}
// If the target Go value is a pointer and the BSON field value is empty, set the value to the
// zero value of the pointer (nil) and don't call UnmarshalBSON. UnmarshalBSON has no way to
// change the pointer value from within the function (only the value at the pointer address),
// so it can't set the pointer to "nil" itself. Since the most common Go value for an empty BSON
// field value is "nil", we set "nil" here and don't call UnmarshalBSON. This behavior matches
// the behavior of the Go "encoding/json" unmarshaler when the target Go value is a pointer and
// the JSON field value is "null".
// NOTE(review): val.Set is called here without a CanSet check; a non-nil,
// non-settable pointer would presumably panic. Confirm callers never pass one.
if val.Kind() == reflect.Ptr && len(src) == 0 {
val.Set(reflect.Zero(val.Type()))
return nil
}
if !val.Type().Implements(tUnmarshaler) {
if !val.CanAddr() {
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
}
m, ok := val.Interface().(Unmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
return m.UnmarshalBSON(src)
}
// coreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. val
// must be a settable bsoncore.Document. A nil val is first replaced by an
// empty slice, then val is truncated to length zero and handed to
// appendDocumentBytes as the buffer to append to. The slice returned by
// appendDocumentBytes is stored back into val even when an error is returned.
func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tCoreDocument) {
		return ValueDecoderError{
			Name:     "CoreDocumentDecodeValue",
			Types:    []reflect.Type{tCoreDocument},
			Received: val,
		}
	}

	if val.IsNil() {
		empty := reflect.MakeSlice(val.Type(), 0, 0)
		val.Set(empty)
	}
	val.SetLen(0)

	buf := val.Interface().(bsoncore.Document)
	doc, appendErr := appendDocumentBytes(buf, vr)
	val.Set(reflect.ValueOf(doc))
	return appendErr
}
// decodeDefault reads a BSON array from vr and returns one reflect.Value per
// array element, in order. val supplies the element type and, for
// interface-typed elements, the existing element values to decode into.
//
// Errors from decoding an individual element are wrapped with newDecodeError
// using the element's index as the key; errors from reading the array or
// looking up a decoder are returned as-is.
func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) {
elems := make([]reflect.Value, 0)
ar, err := vr.ReadArray()
if err != nil {
return nil, err
}
eType := val.Type().Elem()
// Look up a single decoder for the element type up front, except when the
// elements are interface-typed and val already contains elements. In that
// case vDecoder stays nil and a decoder is chosen per element below, based
// on the dynamic type of the existing element.
var vDecoder ValueDecoder
if !(eType.Kind() == reflect.Interface && val.Len() > 0) {
vDecoder, err = dc.LookupDecoder(eType)
if err != nil {
return nil, err
}
}
idx := 0
for {
// vr and err are shadowed inside the loop: vr is the reader for the current
// array element. ErrEOA marks the end of the array.
vr, err := ar.ReadValue()
if errors.Is(err, ErrEOA) {
break
}
if err != nil {
return nil, err
}
var elem reflect.Value
if vDecoder == nil {
// NOTE(review): val.Index(idx) is not checked against val.Len(), and
// elem.Type() is called without checking that elem is valid; presumably
// the BSON array never has more elements than val and existing elements
// are never nil interfaces. Confirm against callers.
elem = val.Index(idx).Elem()
switch {
case elem.Kind() != reflect.Ptr || elem.IsNil():
// Existing element is a non-pointer or a nil pointer: decode using the
// decoder registered for its dynamic type.
valueDecoder, err := dc.LookupDecoder(elem.Type())
if err != nil {
return nil, err
}
err = valueDecoder.DecodeValue(dc, vr, elem)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
case vr.Type() == TypeNull:
// Existing element is a non-nil pointer but the BSON value is null: the
// result is the zero value of the element type.
if err = vr.ReadNull(); err != nil {
return nil, err
}
elem = reflect.Zero(val.Index(idx).Type())
default:
// Existing element is a non-nil pointer: decode into the value it points
// to. elem itself (the pointer) is what gets appended below.
e := elem.Elem()
valueDecoder, err := dc.LookupDecoder(e.Type())
if err != nil {
return nil, err
}
err = valueDecoder.DecodeValue(dc, vr, e)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
}
} else {
elem, err = decodeTypeOrValueWithInfo(vDecoder, dc, vr, eType)
if err != nil {
return nil, newDecodeError(strconv.Itoa(idx), err)
}
}
elems = append(elems, elem)
idx++
}
return elems, nil
}
// codeWithScopeDecodeType decodes the next value from vr into a CodeWithScope
// wrapped in a reflect.Value. t must be tCodeWithScope. For a BSON
// code-with-scope value the scope document is decoded with
// decodeElemsFromDocumentReader and collected into a D. Null and undefined are
// consumed and yield a zero CodeWithScope. Every other BSON type is an error.
func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tCodeWithScope {
		return emptyValue, ValueDecoderError{
			Name:     "CodeWithScopeDecodeValue",
			Types:    []reflect.Type{tCodeWithScope},
			Received: reflect.Zero(t),
		}
	}

	var result CodeWithScope

	bsonType := vr.Type()
	switch bsonType {
	case TypeCodeWithScope:
		code, scopeReader, err := vr.ReadCodeWithScope()
		if err != nil {
			return emptyValue, err
		}
		scopeElems, err := decodeElemsFromDocumentReader(dc, scopeReader)
		if err != nil {
			return emptyValue, err
		}

		// Build the D through reflection because the elements arrive as
		// reflect.Values.
		scopeVal := reflect.Append(reflect.MakeSlice(tD, 0, len(scopeElems)), scopeElems...)
		result.Code = JavaScript(code)
		result.Scope = scopeVal.Interface().(D)
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a CodeWithScope", bsonType)
	}

	return reflect.ValueOf(result), nil
}
// codeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. val must
// be a settable CodeWithScope value; the result of codeWithScopeDecodeType is
// stored into it.
func codeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tCodeWithScope) {
		return ValueDecoderError{
			Name:     "CodeWithScopeDecodeValue",
			Types:    []reflect.Type{tCodeWithScope},
			Received: val,
		}
	}

	decoded, decodeErr := codeWithScopeDecodeType(dc, vr, tCodeWithScope)
	if decodeErr != nil {
		return decodeErr
	}

	val.Set(decoded)
	return nil
}
// decodeD reads a document from vr and returns its elements as reflect.Values
// wrapping E, as produced by decodeElemsFromDocumentReader. The reader must
// report Type(0) or TypeEmbeddedDocument; any other type is an error. The
// reflect.Value argument is unused.
func decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) {
	if bsonType := vr.Type(); bsonType != Type(0) && bsonType != TypeEmbeddedDocument {
		return nil, fmt.Errorf("cannot decode %v into a D", vr.Type())
	}

	docReader, readErr := vr.ReadDocument()
	if readErr != nil {
		return nil, readErr
	}
	return decodeElemsFromDocumentReader(dc, docReader)
}
// decodeElemsFromDocumentReader reads every element from dr until ErrEOD and
// returns them as reflect.Values wrapping E. Each element value is decoded
// with the decoder registered for tEmpty. A failure to decode an element is
// wrapped with newDecodeError using the element's key; other errors are
// returned as-is. An empty document yields an empty, non-nil slice.
func decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) {
	anyDecoder, lookupErr := dc.LookupDecoder(tEmpty)
	if lookupErr != nil {
		return nil, lookupErr
	}

	result := make([]reflect.Value, 0)
	for {
		name, elemReader, readErr := dr.ReadElement()
		switch {
		case errors.Is(readErr, ErrEOD):
			return result, nil
		case readErr != nil:
			return nil, readErr
		}

		target := reflect.New(tEmpty).Elem()
		if decErr := anyDecoder.DecodeValue(dc, elemReader, target); decErr != nil {
			return nil, newDecodeError(name, decErr)
		}

		entry := E{Key: name, Value: target.Interface()}
		result = append(result, reflect.ValueOf(entry))
	}
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"errors"
"math"
"net/url"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
var (
	// bvwPool is a pool of *valueWriter; New allocates a fresh one with new.
	bvwPool = sync.Pool{
		New: func() interface{} { return new(valueWriter) },
	}

	// errInvalidValue is the sentinel error for an invalid element.
	errInvalidValue = errors.New("cannot encode invalid element")

	// sliceWriterPool is a pool of *sliceWriter; New returns a pointer to a
	// new, empty sliceWriter.
	sliceWriterPool = sync.Pool{
		New: func() interface{} {
			w := make(sliceWriter, 0)
			return &w
		},
	}
)
// encodeElement writes e to dw as a single document element. A nil e.Value is
// written as BSON null; any other value is written with the encoder that ec
// returns for the value's dynamic type. The first error encountered is
// returned.
func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error {
	vw, writeErr := dw.WriteDocumentElement(e.Key)
	if writeErr != nil {
		return writeErr
	}

	if e.Value == nil {
		return vw.WriteNull()
	}

	enc, lookupErr := ec.LookupEncoder(reflect.TypeOf(e.Value))
	if lookupErr != nil {
		return lookupErr
	}
	return enc.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
}
// registerDefaultEncoders registers the package's default value encoders with
// reg: first the encoders keyed by exact type, then the encoders keyed by
// reflect.Kind, and finally the interface encoders for tValueMarshaler and
// tMarshaler. One mapCodec instance is shared between the map kind encoder and
// the struct codec.
func registerDefaultEncoders(reg *Registry) {
	mapEnc := &mapCodec{}
	uintEnc := &uintCodec{}
	intEnc := ValueEncoderFunc(intEncodeValue)
	floatEnc := ValueEncoderFunc(floatEncodeValue)

	// Encoders keyed by exact type.
	reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
	reg.RegisterTypeEncoder(tTime, &timeCodec{})
	reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
	reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
	reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
	reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
	reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
	reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
	reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
	reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
	reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
	reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
	reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
	reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
	reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
	reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
	reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
	reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
	reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
	reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
	reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
	reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))

	// Encoders keyed by kind. The numeric kinds share one encoder per family.
	reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
	for _, k := range []reflect.Kind{reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64} {
		reg.RegisterKindEncoder(k, intEnc)
	}
	for _, k := range []reflect.Kind{reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64} {
		reg.RegisterKindEncoder(k, uintEnc)
	}
	for _, k := range []reflect.Kind{reflect.Float32, reflect.Float64} {
		reg.RegisterKindEncoder(k, floatEnc)
	}
	reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue))
	reg.RegisterKindEncoder(reflect.Map, mapEnc)
	reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
	reg.RegisterKindEncoder(reflect.String, &stringCodec{})
	reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEnc))
	reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{})

	// Encoders keyed by implemented interface.
	reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue))
	reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue))
}
// booleanEncodeValue is the ValueEncoderFunc for bool types. val must be a
// valid value of kind Bool; it is written to vw with WriteBoolean.
func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !(val.IsValid() && val.Kind() == reflect.Bool) {
		return ValueEncoderError{
			Name:     "BooleanEncodeValue",
			Kinds:    []reflect.Kind{reflect.Bool},
			Received: val,
		}
	}

	flag := val.Bool()
	return vw.WriteBoolean(flag)
}
// fitsIn32Bits reports whether i lies within the inclusive range
// [math.MinInt32, math.MaxInt32], i.e. whether it can be represented as an
// int32 without loss.
func fitsIn32Bits(i int64) bool {
	if i < math.MinInt32 {
		return false
	}
	return i <= math.MaxInt32
}
// intEncodeValue is the ValueEncoderFunc for int types.
//
// Int8, int16 and int32 values are always written as BSON int32. An int is
// written as int32 when it fits in 32 bits and as int64 otherwise. An int64 is
// written as int64 unless ec.minSize is set and the value fits in 32 bits.
// Any other kind yields a ValueEncoderError.
func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	var (
		n      int64
		narrow bool
	)

	switch val.Kind() {
	case reflect.Int8, reflect.Int16, reflect.Int32:
		n, narrow = val.Int(), true
	case reflect.Int:
		n = val.Int()
		narrow = fitsIn32Bits(n)
	case reflect.Int64:
		n = val.Int()
		narrow = ec.minSize && fitsIn32Bits(n)
	default:
		return ValueEncoderError{
			Name:     "IntEncodeValue",
			Kinds:    []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
			Received: val,
		}
	}

	if narrow {
		return vw.WriteInt32(int32(n))
	}
	return vw.WriteInt64(n)
}
// floatEncodeValue is the ValueEncoderFunc for float types. Both float32 and
// float64 values are written with WriteDouble; any other kind yields a
// ValueEncoderError.
func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if k := val.Kind(); k == reflect.Float32 || k == reflect.Float64 {
		return vw.WriteDouble(val.Float())
	}
	return ValueEncoderError{
		Name:     "FloatEncodeValue",
		Kinds:    []reflect.Kind{reflect.Float32, reflect.Float64},
		Received: val,
	}
}
// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. It returns a
// ValueEncoderError unless val is a valid value whose type is tOID.
func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tOID; !ok {
		return ValueEncoderError{
			Name:     "ObjectIDEncodeValue",
			Types:    []reflect.Type{tOID},
			Received: val,
		}
	}
	oid := val.Interface().(ObjectID)
	return vw.WriteObjectID(oid)
}
// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. It returns a
// ValueEncoderError unless val is a valid value whose type is tDecimal.
func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tDecimal; !ok {
		return ValueEncoderError{
			Name:     "Decimal128EncodeValue",
			Types:    []reflect.Type{tDecimal},
			Received: val,
		}
	}
	dec := val.Interface().(Decimal128)
	return vw.WriteDecimal128(dec)
}
// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number.
//
// The number is first parsed as an int64 and, if that succeeds, encoded via
// intEncodeValue. Otherwise it is parsed as a float64 and encoded via
// floatEncodeValue; a float parse failure is returned to the caller.
func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tJSONNumber; !ok {
		return ValueEncoderError{
			Name:     "JSONNumberEncodeValue",
			Types:    []reflect.Type{tJSONNumber},
			Received: val,
		}
	}

	num := val.Interface().(json.Number)

	// Integers take priority over floating point values.
	asInt, intErr := num.Int64()
	if intErr == nil {
		return intEncodeValue(ec, vw, reflect.ValueOf(asInt))
	}

	asFloat, floatErr := num.Float64()
	if floatErr != nil {
		return floatErr
	}
	return floatEncodeValue(ec, vw, reflect.ValueOf(asFloat))
}
// urlEncodeValue is the ValueEncoderFunc for url.URL. The URL is written as a
// string using its String form.
func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tURL; !ok {
		return ValueEncoderError{
			Name:     "URLEncodeValue",
			Types:    []reflect.Type{tURL},
			Received: val,
		}
	}
	target := val.Interface().(url.URL)
	text := target.String()
	return vw.WriteString(text)
}
// arrayEncodeValue is the ValueEncoderFunc for array types.
//
// Two element types are special-cased:
//   - an array of E is written as a document, one element per array entry;
//   - an array of byte is written with WriteBinary.
//
// Every other array is written as an array, element by element. A nil
// interface element is written as null. A ValueEncoderError is returned when
// val is invalid or is not of kind reflect.Array.
func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Array {
		return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
	}

	elemType := val.Type().Elem()

	// If we have a [N]E we want to treat it as a document instead of as an array.
	if elemType == tE {
		dw, err := vw.WriteDocument()
		if err != nil {
			return err
		}

		for idx := 0; idx < val.Len(); idx++ {
			e := val.Index(idx).Interface().(E)
			err = encodeElement(ec, dw, e)
			if err != nil {
				return err
			}
		}

		return dw.WriteDocumentEnd()
	}

	// If we have a [N]byte we want to treat it as a binary instead of as an array.
	if elemType == tByte {
		// Copy the whole array in one step instead of appending byte by byte.
		// A zero-length array still produces a nil slice, as before.
		var byteSlice []byte
		if n := val.Len(); n > 0 {
			byteSlice = make([]byte, n)
			reflect.Copy(reflect.ValueOf(byteSlice), val)
		}
		return vw.WriteBinary(byteSlice)
	}

	aw, err := vw.WriteArray()
	if err != nil {
		return err
	}

	// A failed lookup is tolerated for interface element types: the encoder is
	// then resolved per element from each element's dynamic type.
	encoder, err := ec.LookupEncoder(elemType)
	if err != nil && elemType.Kind() != reflect.Interface {
		return err
	}

	for idx := 0; idx < val.Len(); idx++ {
		currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
		if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
			return lookupErr
		}

		elemVW, err := aw.WriteArrayElement()
		if err != nil {
			return err
		}

		// errInvalidValue marks a nil interface element; encode it as null.
		if errors.Is(lookupErr, errInvalidValue) {
			err = elemVW.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		err = currEncoder.EncodeValue(ec, elemVW, currVal)
		if err != nil {
			return err
		}
	}
	return aw.WriteArrayEnd()
}
// lookupElementEncoder selects the encoder and value to use for one
// collection element.
//
// When origEncoder is non-nil, or currVal is not of kind reflect.Interface,
// origEncoder and currVal are returned unchanged. Otherwise the interface is
// unwrapped: a nil interface yields errInvalidValue, and any other value is
// returned together with the encoder looked up for its dynamic type.
func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
	needsLookup := origEncoder == nil && currVal.Kind() == reflect.Interface
	if !needsLookup {
		return origEncoder, currVal, nil
	}

	concrete := currVal.Elem()
	if !concrete.IsValid() {
		return nil, concrete, errInvalidValue
	}

	enc, err := ec.LookupEncoder(concrete.Type())
	return enc, concrete, err
}
// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
//
// Either val's type or a pointer to it must implement ValueMarshaler; in the
// pointer case val must be addressable. A nil pointer whose element type
// implements the interface is written as null. The marshaled type and bytes
// are copied to vw with copyValueFromBytes.
func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() {
		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
	}

	valType := val.Type()
	if valType.Implements(tValueMarshaler) {
		// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
		if isImplementationNil(val, tValueMarshaler) {
			return vw.WriteNull()
		}
	} else if reflect.PtrTo(valType).Implements(tValueMarshaler) && val.CanAddr() {
		val = val.Addr()
	} else {
		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
	}

	marshaler, isMarshaler := val.Interface().(ValueMarshaler)
	if !isMarshaler {
		return vw.WriteNull()
	}

	bsonType, raw, err := marshaler.MarshalBSONValue()
	if err != nil {
		return err
	}
	return copyValueFromBytes(vw, Type(bsonType), raw)
}
// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
//
// Either val's type or a pointer to it must implement Marshaler; in the
// pointer case val must be addressable. A nil pointer whose element type
// implements the interface is written as null. The marshaled bytes are copied
// to vw as an embedded document.
func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() {
		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
	}

	valType := val.Type()
	if valType.Implements(tMarshaler) {
		// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
		if isImplementationNil(val, tMarshaler) {
			return vw.WriteNull()
		}
	} else if reflect.PtrTo(valType).Implements(tMarshaler) && val.CanAddr() {
		val = val.Addr()
	} else {
		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
	}

	marshaler, isMarshaler := val.Interface().(Marshaler)
	if !isMarshaler {
		return vw.WriteNull()
	}

	raw, err := marshaler.MarshalBSON()
	if err != nil {
		return err
	}
	return copyValueFromBytes(vw, TypeEmbeddedDocument, raw)
}
// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. It
// returns a ValueEncoderError unless val is a valid value whose type is
// tJavaScript.
func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tJavaScript; !ok {
		return ValueEncoderError{
			Name:     "JavaScriptEncodeValue",
			Types:    []reflect.Type{tJavaScript},
			Received: val,
		}
	}
	code := val.String()
	return vw.WriteJavascript(code)
}
// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. It returns a
// ValueEncoderError unless val is a valid value whose type is tSymbol.
func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tSymbol; !ok {
		return ValueEncoderError{
			Name:     "SymbolEncodeValue",
			Types:    []reflect.Type{tSymbol},
			Received: val,
		}
	}
	sym := val.String()
	return vw.WriteSymbol(sym)
}
// binaryEncodeValue is the ValueEncoderFunc for Binary. The value's Data and
// Subtype are passed to WriteBinaryWithSubtype.
func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tBinary; !ok {
		return ValueEncoderError{
			Name:     "BinaryEncodeValue",
			Types:    []reflect.Type{tBinary},
			Received: val,
		}
	}
	bin := val.Interface().(Binary)
	return vw.WriteBinaryWithSubtype(bin.Data, bin.Subtype)
}
// vectorEncodeValue is the ValueEncoderFunc for Vector. The vector is
// converted with its Binary method and the resulting Data and Subtype are
// passed to WriteBinaryWithSubtype.
//
// It returns a ValueEncoderError when val is invalid or its type is not
// tVector.
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	// IsValid must be checked before Type: calling Type on the zero
	// reflect.Value panics instead of reporting an encoder error.
	if !val.IsValid() || val.Type() != tVector {
		return ValueEncoderError{
			Name:     "VectorEncodeValue",
			Types:    []reflect.Type{tVector},
			Received: val,
		}
	}

	v := val.Interface().(Vector)
	b := v.Binary()
	return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// undefinedEncodeValue is the ValueEncoderFunc for Undefined. It returns a
// ValueEncoderError unless val is a valid value whose type is tUndefined.
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tUndefined; !ok {
		return ValueEncoderError{
			Name:     "UndefinedEncodeValue",
			Types:    []reflect.Type{tUndefined},
			Received: val,
		}
	}
	return vw.WriteUndefined()
}
// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. The value's
// integer representation is passed to WriteDateTime.
func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tDateTime; !ok {
		return ValueEncoderError{
			Name:     "DateTimeEncodeValue",
			Types:    []reflect.Type{tDateTime},
			Received: val,
		}
	}
	dt := val.Int()
	return vw.WriteDateTime(dt)
}
// nullEncodeValue is the ValueEncoderFunc for Null. It returns a
// ValueEncoderError unless val is a valid value whose type is tNull.
func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tNull; !ok {
		return ValueEncoderError{
			Name:     "NullEncodeValue",
			Types:    []reflect.Type{tNull},
			Received: val,
		}
	}
	return vw.WriteNull()
}
// regexEncodeValue is the ValueEncoderFunc for Regex. The value's Pattern and
// Options are passed to WriteRegex.
func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tRegex; !ok {
		return ValueEncoderError{
			Name:     "RegexEncodeValue",
			Types:    []reflect.Type{tRegex},
			Received: val,
		}
	}
	re := val.Interface().(Regex)
	return vw.WriteRegex(re.Pattern, re.Options)
}
// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. The value's DB
// and Pointer fields are passed to WriteDBPointer.
func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tDBPointer; !ok {
		return ValueEncoderError{
			Name:     "DBPointerEncodeValue",
			Types:    []reflect.Type{tDBPointer},
			Received: val,
		}
	}
	ptr := val.Interface().(DBPointer)
	return vw.WriteDBPointer(ptr.DB, ptr.Pointer)
}
// timestampEncodeValue is the ValueEncoderFunc for Timestamp. The value's T
// and I fields are passed to WriteTimestamp.
func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tTimestamp; !ok {
		return ValueEncoderError{
			Name:     "TimestampEncodeValue",
			Types:    []reflect.Type{tTimestamp},
			Received: val,
		}
	}
	stamp := val.Interface().(Timestamp)
	return vw.WriteTimestamp(stamp.T, stamp.I)
}
// minKeyEncodeValue is the ValueEncoderFunc for MinKey. It returns a
// ValueEncoderError unless val is a valid value whose type is tMinKey.
func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tMinKey; !ok {
		return ValueEncoderError{
			Name:     "MinKeyEncodeValue",
			Types:    []reflect.Type{tMinKey},
			Received: val,
		}
	}
	return vw.WriteMinKey()
}
// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. It returns a
// ValueEncoderError unless val is a valid value whose type is tMaxKey.
func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tMaxKey; !ok {
		return ValueEncoderError{
			Name:     "MaxKeyEncodeValue",
			Types:    []reflect.Type{tMaxKey},
			Received: val,
		}
	}
	return vw.WriteMaxKey()
}
// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. The
// document bytes are copied to vw with copyDocumentFromBytes.
func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tCoreDocument; !ok {
		return ValueEncoderError{
			Name:     "CoreDocumentEncodeValue",
			Types:    []reflect.Type{tCoreDocument},
			Received: val,
		}
	}
	doc := val.Interface().(bsoncore.Document)
	return copyDocumentFromBytes(vw, doc)
}
// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
//
// The code string is written through vw.WriteCodeWithScope, which returns the
// document writer for the scope. The scope itself is first encoded into a
// pooled scratch writer using the encoder registered for the scope's dynamic
// type, and the resulting bytes are then copied into the scope document
// writer.
//
// It returns a ValueEncoderError when val is invalid or its type is not
// tCodeWithScope, and otherwise the first error reported by any writer or by
// the encoder lookup.
func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
// Borrow a slice writer from the pool and truncate it to zero length; it is
// handed back to the pool when this function returns.
sw := sliceWriterPool.Get().(*sliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
// Borrow a value writer, reset it onto its own truncated buffer and point
// its output at sw; it is also handed back to the pool on return.
scopeVW := bvwPool.Get().(*valueWriter)
scopeVW.reset(scopeVW.buf[:0])
scopeVW.w = sw
defer bvwPool.Put(scopeVW)
// Encode the scope with the encoder registered for its dynamic type.
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
// Copy the contents of *sw into the scope document writer, then close it.
err = copyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// isImplementationNil reports whether val is a nil pointer whose fully
// dereferenced type implements inter, i.e. inter is implemented on the
// concrete (non-pointer) type and val holds no value to call it on.
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
	// Strip every level of pointer indirection to reach the concrete type.
	base := val.Type()
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}

	if !base.Implements(inter) {
		return false
	}
	if val.Kind() != reflect.Ptr {
		return false
	}
	return val.IsNil()
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// emptyInterfaceCodec is the Codec used for interface{} values. Its
// EncodeValue method encodes the dynamic value held by the interface, and its
// DecodeValue method chooses a concrete Go type from the BSON type being read.
type emptyInterfaceCodec struct {
// decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the
// "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary.
// The binaryAsSlice flag on the DecodeContext has the same effect.
decodeBinaryAsSlice bool
}
// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &emptyInterfaceCodec{}
// EncodeValue is the ValueEncoderFunc for interface{}. A nil interface is
// written as null; any other value is encoded with the encoder registered for
// its dynamic type.
func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if ok := val.IsValid() && val.Type() == tEmpty; !ok {
		return ValueEncoderError{
			Name:     "EmptyInterfaceEncodeValue",
			Types:    []reflect.Type{tEmpty},
			Received: val,
		}
	}

	if val.IsNil() {
		return vw.WriteNull()
	}

	inner := val.Elem()
	enc, err := ec.LookupEncoder(inner.Type())
	if err != nil {
		return err
	}
	return enc.EncodeValue(ec, vw, inner)
}
// getEmptyInterfaceDecodeType returns the Go type that a BSON value of type
// valueType should be decoded into when the destination is an interface{}.
//
// For document types (Type(0) or TypeEmbeddedDocument) the lookup order is:
// the DecodeContext's defaultDocumentType, the type map entry for valueType,
// the type map entry for the other document type, and finally tD. For every
// other type the type map entry is returned, or the lookup error if there is
// none.
func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) {
	isDoc := valueType == Type(0) || valueType == TypeEmbeddedDocument

	// An explicitly configured document type on the DecodeContext wins.
	if isDoc && dc.defaultDocumentType != nil {
		return dc.defaultDocumentType, nil
	}

	rtype, err := dc.LookupTypeMapEntry(valueType)
	if err == nil {
		return rtype, nil
	}
	if !isDoc {
		return nil, err
	}

	// For documents, retry with the other document type: Type(0) falls back
	// to TypeEmbeddedDocument and vice versa.
	alternate := Type(0)
	if valueType == Type(0) {
		alternate = TypeEmbeddedDocument
	}
	if rtype, err = dc.LookupTypeMapEntry(alternate); err == nil {
		return rtype, nil
	}

	// fallback to bson.D
	return tD, nil
}
// decodeType decodes the next value from vr into a new reflect.Value whose Go
// type is chosen from the BSON type reported by vr. t must be tEmpty.
//
// When no Go type can be determined and the BSON value is null, the zero
// value of t is returned after reading the null. When binary-as-slice
// decoding is enabled on the codec or the DecodeContext, a decoded Binary with
// the generic or old subtype is replaced by its Data byte slice.
func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tEmpty {
		return emptyValue, ValueDecoderError{
			Name:     "EmptyInterfaceDecodeValue",
			Types:    []reflect.Type{tEmpty},
			Received: reflect.Zero(t),
		}
	}

	rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
	if err != nil {
		if vr.Type() == TypeNull {
			return reflect.Zero(t), vr.ReadNull()
		}
		return emptyValue, err
	}

	decoder, err := dc.LookupDecoder(rtype)
	if err != nil {
		return emptyValue, err
	}

	elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype)
	if err != nil {
		return emptyValue, err
	}

	asSlice := rtype == tBinary && (eic.decodeBinaryAsSlice || dc.binaryAsSlice)
	if !asSlice {
		return elem, nil
	}

	bin := elem.Interface().(Binary)
	if bin.Subtype == TypeBinaryGeneric || bin.Subtype == TypeBinaryBinaryOld {
		return reflect.ValueOf(bin.Data), nil
	}
	return elem, nil
}
// DecodeValue is the ValueDecoderFunc for interface{}. It decodes the next
// value from vr with decodeType and stores the result in val, which must be
// settable and of type tEmpty.
func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if ok := val.CanSet() && val.Type() == tEmpty; !ok {
		return ValueDecoderError{
			Name:     "EmptyInterfaceDecodeValue",
			Types:    []reflect.Type{tEmpty},
			Received: val,
		}
	}

	decoded, err := eic.decodeType(dc, vr, val.Type())
	if err == nil {
		val.Set(decoded)
	}
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
)
// encPool recycles Encoder values to keep allocations down. It is only used by
// the Marshal* functions and is not consumable from outside of this package.
// An Encoder retrieved from this pool must have both Reset and SetRegistry
// called on it before use.
var encPool = sync.Pool{
	New: func() interface{} {
		return &Encoder{}
	},
}
// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
// as the destination of BSON data.
//
// The option methods on Encoder (for example IntMinSize or OmitEmpty) set
// flags on the embedded EncodeContext, which is passed to every value encoder
// invoked by Encode.
type Encoder struct {
// ec holds the registry and encoding options used by Encode.
ec EncodeContext
// vw is the destination that Encode writes to.
vw ValueWriter
}
// NewEncoder returns a new encoder that writes to vw. The encoder starts with
// the default registry and no encoding options set.
func NewEncoder(vw ValueWriter) *Encoder {
	enc := new(Encoder)
	enc.ec.Registry = defaultRegistry
	enc.vw = vw
	return enc
}
// Encode writes the BSON encoding of val to the stream.
//
// If val implements Marshaler, its MarshalBSON output is copied to the stream
// as a document. Otherwise val is encoded with the encoder the Encoder's
// registry provides for val's type.
//
// See [Marshal] for details about BSON marshaling behavior.
func (e *Encoder) Encode(val interface{}) error {
	if m, isMarshaler := val.(Marshaler); isMarshaler {
		// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
		raw, err := m.MarshalBSON()
		if err != nil {
			return err
		}
		return copyDocumentFromBytes(e.vw, raw)
	}

	enc, err := e.ec.LookupEncoder(reflect.TypeOf(val))
	if err != nil {
		return err
	}
	return enc.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
}
// Reset replaces the ValueWriter the Encoder writes to with vw. The Encoder's
// EncodeContext, including its registry and any options already set, is left
// unchanged.
func (e *Encoder) Reset(vw ValueWriter) {
e.vw = vw
}
// SetRegistry replaces the current registry of the Encoder with r. Options
// already set on the Encoder are not affected.
func (e *Encoder) SetRegistry(r *Registry) {
e.ec.Registry = r
}
// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
// the marshaled BSON when the "inline" struct tag option is set.
//
// It sets the errorOnInlineDuplicates flag on the Encoder's EncodeContext.
func (e *Encoder) ErrorOnInlineDuplicates() {
e.ec.errorOnInlineDuplicates = true
}
// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint,
// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can
// represent the integer value.
//
// It sets the minSize flag on the Encoder's EncodeContext.
func (e *Encoder) IntMinSize() {
e.ec.minSize = true
}
// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name
// strings using fmt.Sprint instead of the default string conversion logic.
//
// It sets the stringifyMapKeysWithFmt flag on the Encoder's EncodeContext.
func (e *Encoder) StringifyMapKeysWithFmt() {
e.ec.stringifyMapKeysWithFmt = true
}
// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON
// null.
//
// It sets the nilMapAsEmpty flag on the Encoder's EncodeContext.
func (e *Encoder) NilMapAsEmpty() {
e.ec.nilMapAsEmpty = true
}
// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON
// null.
//
// It sets the nilSliceAsEmpty flag on the Encoder's EncodeContext.
func (e *Encoder) NilSliceAsEmpty() {
e.ec.nilSliceAsEmpty = true
}
// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
//
// It sets the nilByteSliceAsEmpty flag on the Encoder's EncodeContext.
func (e *Encoder) NilByteSliceAsEmpty() {
e.ec.nilByteSliceAsEmpty = true
}
// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported
// TODO struct fields once the logic is updated to also inspect private struct fields.
// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{})
// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set
// or the OmitEmpty() method is called.
//
// Note that the Encoder only examines exported struct fields when determining if a struct is the
// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty.
//
// It sets the omitZeroStruct flag on the Encoder's EncodeContext.
func (e *Encoder) OmitZeroStruct() {
e.ec.omitZeroStruct = true
}
// OmitEmpty causes the Encoder to omit empty values from the marshaled BSON, as if the
// "omitempty" struct tag option were set.
//
// It sets the omitEmpty flag on the Encoder's EncodeContext.
func (e *Encoder) OmitEmpty() {
e.ec.omitEmpty = true
}
// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
//
// It sets the useJSONStructTags flag on the Encoder's EncodeContext.
func (e *Encoder) UseJSONStructTags() {
e.ec.useJSONStructTags = true
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
// maxNestingDepth is the nesting limit that newExtJSONParser assigns to a new
// parser's maxDepth field.
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
// jsonParseState identifies where the extended JSON parser is in the token
// stream. advanceState updates it from the most recently scanned token.
type jsonParseState byte
const (
// jpsStartState is the state of a new parser before any token is read.
jpsStartState jsonParseState = iota
// jpsSawBeginObject: the last token opened an object.
jpsSawBeginObject
// jpsSawEndObject: the last token closed an object.
jpsSawEndObject
// jpsSawBeginArray: the last token opened an array.
jpsSawBeginArray
// jpsSawEndArray: the last token closed an array.
jpsSawEndArray
// jpsSawColon: the last token was a colon.
jpsSawColon
// jpsSawComma: the last token was a comma.
jpsSawComma
// jpsSawKey: the last token was a string interpreted as an object key.
jpsSawKey
// jpsSawValue: the last token was a literal value.
jpsSawValue
// jpsDoneState: the end of the input was reached with no open object or array.
jpsDoneState
// jpsInvalidState: an error occurred; it is stored in the parser's err field.
jpsInvalidState
)
// jsonParseMode records which kind of container the parser is currently
// inside. advanceState pushes a mode when an object or array begins and pops
// it when the container ends.
type jsonParseMode byte
const (
// jpmInvalidMode is the zero value of jsonParseMode.
jpmInvalidMode jsonParseMode = iota
// jpmObjectMode: the innermost open container is an object.
jpmObjectMode
// jpmArrayMode: the innermost open container is an array.
jpmArrayMode
)
// extJSONValue is a parsed extended JSON value: a BSON Type tag together with
// the Go value that carries its data. For wrapper types read by readValue, v
// can be an *extJSONObject holding the wrapper's key/value pairs.
type extJSONValue struct {
t Type
v interface{}
}
// extJSONObject holds the key/value pairs of a parsed extended JSON object as
// two parallel slices: values[i] is the value read for keys[i].
type extJSONObject struct {
keys []string
values []*extJSONValue
}
// extJSONParser is a state machine that reads tokens from a jsonScanner and
// exposes them as extended JSON keys, types and values through peekType,
// readKey and readValue.
type extJSONParser struct {
// js is the scanner that supplies tokens.
js *jsonScanner
// s is the current parse state.
s jsonParseState
// m is the stack of containers (objects/arrays) currently open.
m []jsonParseMode
// k is the most recently read object key.
k string
// v is the most recently read literal value.
v *extJSONValue
// err is the error that put the parser into jpsInvalidState.
err error
// canonicalOnly makes readValue reject a $date wrapper whose value is a
// literal rather than an object.
canonicalOnly bool
// depth counts the objects currently open; advanceState fails once it
// exceeds maxDepth.
depth int
maxDepth int
// emptyObject is set by peekType when it has already consumed the closing
// brace of an empty object, so the next readKey reports ErrEOD.
emptyObject bool
// relaxedUUID is set by peekType when the wrapper key is $uuid and cleared
// by readValue once the UUID has been converted.
relaxedUUID bool
}
// newExtJSONParser returns a new extended JSON parser, ready to begin
// parsing from the first character of the given JSON input. It does not
// perform any read-ahead and therefore does not report any errors about
// malformed JSON at this point.
//
// canonicalOnly is stored on the parser; the nesting limit is set to
// maxNestingDepth.
func newExtJSONParser(r io.Reader, canonicalOnly bool) *extJSONParser {
	scanner := &jsonScanner{r: r}
	parser := &extJSONParser{
		js:            scanner,
		s:             jpsStartState,
		m:             []jsonParseMode{},
		canonicalOnly: canonicalOnly,
		maxDepth:      maxNestingDepth,
	}
	return parser
}
// peekType examines the next value and returns its BSON Type.
//
// It advances the parser by at least one token. For a literal the type comes
// from the scanned value; for an array it is TypeArray; for an object the
// first key is read to decide whether the object is a plain embedded document
// or an extended JSON wrapper. ErrEOA is returned when the next token closes
// an array, and the parser's stored error is returned when the parser has
// entered the invalid state.
func (ejp *extJSONParser) peekType() (Type, error) {
var t Type
var err error
// Remember the state before advancing: an object opened as the very first
// token is always reported as an embedded document (see below).
initialState := ejp.s
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = TypeArray
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
// (outside array mode this case leaves t at its zero value and err nil)
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = TypeEmbeddedDocument
// The closing brace has been consumed; tell readKey to report end-of-document.
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
// The top-level object is returned as a document without inspecting its
// first key for wrapper types.
if initialState == jpsStartState {
return TypeEmbeddedDocument, nil
}
// NOTE(review): wrapperKeyBSONType is defined elsewhere; presumably it maps
// wrapper keys such as $code to their BSON type - confirm.
t = wrapperKeyBSONType(ejp.k)
// if $uuid is encountered, parse as binary subtype 4
if ejp.k == "$uuid" {
ejp.relaxedUUID = true
t = TypeBinary
}
switch t {
case TypeJavaScript:
// just saw $code, need to check for $scope at same level
_, err = ejp.readValue(TypeJavaScript)
if err != nil {
break
}
// readValue leaves the parser on the token after the $code string: either
// the closing brace or a comma introducing another key.
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = TypeCodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
case TypeCodeWithScope:
// The wrapper key mapped directly to TypeCodeWithScope, i.e. $scope came
// before $code.
err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them.
//
// It returns ErrEOD when the enclosing object has ended (including the empty
// object case recorded by peekType), io.EOF when the input is exhausted, and
// the parser's stored error when the parser is in the invalid state. On
// success the parser has consumed the key and the colon, and peekType has
// already been called for the value.
func (ejp *extJSONParser) readKey() (string, Type, error) {
// peekType already consumed the closing brace of an empty object.
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
// A previous value has just finished; the next token decides whether another
// key follows or the object/input ends.
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType.
//
// The parser must be in the state peekType left it in: on a literal value for
// plain JSON values, or on the wrapper key for extended JSON wrapper objects.
// For wrappers, readValue consumes the remaining tokens of the wrapper object.
// TypeEmbeddedDocument and TypeArray cannot be read through this method and
// yield an invalid-request error. If the parser is already in the invalid
// state its stored error is returned.
func (ejp *extJSONParser) readValue(t Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case TypeNull, TypeBoolean, TypeString:
// These types are only accepted as literals.
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case TypeInt32, TypeInt64, TypeDouble:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
// Otherwise handle them as a single-key wrapper object, like the types below.
fallthrough
case TypeDecimal128, TypeSymbol, TypeObjectID, TypeMinKey, TypeMaxKey, TypeUndefined:
// Single-key wrapper: key, colon, literal value, closing brace.
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeBinary, TypeRegex, TypeTimestamp, TypeDBPointer:
// Wrappers whose payload is normally a nested two-key object. Binary
// additionally accepts two literal-valued forms, handled first.
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == TypeBinary && ejp.s == jpsSawValue {
// convert relaxed $uuid format
if ejp.relaxedUUID {
// Clear the flag on every exit path so it does not leak into the next value.
defer func() { ejp.relaxedUUID = false }()
uuid, err := ejp.v.parseSymbol()
if err != nil {
return nil, err
}
// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
// in the 8th, 13th, 18th, and 23rd characters.
//
// See https://tools.ietf.org/html/rfc4122#section-3
valid := len(uuid) == 36 &&
string(uuid[8]) == "-" &&
string(uuid[13]) == "-" &&
string(uuid[18]) == "-" &&
string(uuid[23]) == "-"
if !valid {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// remove hyphens
uuidNoHyphens := strings.ReplaceAll(uuid, "-", "")
// Fewer than 32 characters remain if the input had hyphens in other
// positions as well.
if len(uuidNoHyphens) != 32 {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("$uuid and value and then }", TypeBinary)
}
// Re-express the UUID as a base64/subType pair with subtype "04", the same
// shape produced for the other binary forms below.
base64 := &extJSONValue{
t: TypeString,
v: base64.StdEncoding.EncodeToString(bytes),
}
subType := &extJSONValue{
t: TypeString,
v: "04",
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// convert legacy $binary format
// Here the $binary value is a literal and must be followed by a sibling
// $type key.
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", TypeBinary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", TypeBinary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", TypeBinary)
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
// readObject consumed the nested object's closing brace; this one closes the
// wrapper object itself.
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case TypeDateTime:
// The $date wrapper value is either a nested single-key object or, unless
// canonicalOnly is set, a literal.
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonicalOnly {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonicalOnly {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeJavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
// peekType already read the $code string while checking for $scope; ejp.v
// still holds it.
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case TypeCodeWithScope:
// Only the $code string is returned here. The parser is left on the opening
// brace of the $scope document, which is not read by this method.
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case TypeEmbeddedDocument, TypeArray:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
//
// numKeys is the exact number of key/value pairs to read. started reports
// whether the opening brace has already been consumed; when it is false the
// brace is read first. The closing brace of the object is consumed before
// returning. The returned slices are parallel: vals[i] belongs to keys[i].
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
// readKey has already peeked at the value, so the state tells whether it is
// a wrapper object (positioned on its first key) or a literal.
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
// Exactly numKeys pairs are allowed; the next token must close the object.
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
//
// It is a no-op once the parser is in the done or invalid state. A scanner
// error, a token rejected by validateToken, a mismatched closing token, an
// EOF with containers still open, or object nesting deeper than maxDepth
// moves the parser to jpsInvalidState and records the cause in ejp.err.
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
// NOTE(review): validateToken is defined elsewhere; presumably it checks the
// token against the transitions allowed from the current state - confirm.
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
// NOTE(review): arrays push a mode but do not change ejp.depth in this
// function, so the maxDepth check above only counts objects - confirm this
// is intended.
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
// EOF while an object or array is still open is an error.
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
// A string is a key when it directly follows an object's opening brace, or a
// comma outside array mode; in every other position it is a value.
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
// jpsValidTransitionTokens maps each parser state to the set of token types
// that may follow it. validateToken consults this table after applying its
// mode-dependent special cases. Membership is decided by key presence, so a
// token type is allowed simply by being listed for the state.
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
// Terminal states: no token is valid after them.
jpsDoneState: {},
jpsInvalidState: {},
}
// validateToken reports whether a token of type jtt may follow the parser's
// current state. Two positions depend on the enclosing mode and are handled
// before the jpsValidTransitionTokens table is consulted.
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
	// At depth zero, a '{' right after a closed object is accepted unless the
	// parser is inside an array.
	if ejp.s == jpsSawEndObject && jtt == jttBeginObject && ejp.depth == 0 {
		return ejp.peekMode() != jpmArrayMode
	}

	if ejp.s == jpsSawComma {
		top := ejp.peekMode()
		if top == jpmObjectMode {
			// Inside a document, only a string (the next key) may follow a comma.
			return jtt == jttString
		}
		if top == jpmInvalidMode {
			return false
		}
	}

	_, allowed := jpsValidTransitionTokens[ejp.s][jtt]
	return allowed
}
// ensureExtValueType reports whether the parser's current value has the JSON
// type required by the single-key extended JSON wrapper for t. For example,
// in {"$numberInt": v} the value v must be a string. Types that have no
// single-key wrapper handled here yield false.
func (ejp *extJSONParser) ensureExtValueType(t Type) bool {
	var want Type
	switch t {
	case TypeInt32, TypeInt64, TypeDouble, TypeDecimal128, TypeSymbol, TypeObjectID:
		want = TypeString
	case TypeMinKey, TypeMaxKey:
		want = TypeInt32
	case TypeUndefined:
		want = TypeBoolean
	default:
		return false
	}
	return ejp.v.t == want
}
// pushMode records m as the innermost open container mode.
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
	modes := append(ejp.m, m)
	ejp.m = modes
}
// popMode removes and returns the innermost container mode. It returns
// jpmInvalidMode, leaving the stack untouched, when no container is open.
func (ejp *extJSONParser) popMode() jsonParseMode {
	last := len(ejp.m) - 1
	if last < 0 {
		return jpmInvalidMode
	}
	top := ejp.m[last]
	ejp.m = ejp.m[:last]
	return top
}
// peekMode returns the innermost container mode without removing it, or
// jpmInvalidMode when no container is open.
func (ejp *extJSONParser) peekMode() jsonParseMode {
	if n := len(ejp.m); n > 0 {
		return ejp.m[n-1]
	}
	return jpmInvalidMode
}
// extendJSONToken wraps a scalar JSON token in an extJSONValue tagged with
// the matching BSON type. It returns nil for token types that do not carry a
// scalar value.
func extendJSONToken(jt *jsonToken) *extJSONValue {
	wrap := func(t Type) *extJSONValue {
		return &extJSONValue{t: t, v: jt.v}
	}

	switch jt.t {
	case jttInt32:
		return wrap(TypeInt32)
	case jttInt64:
		return wrap(TypeInt64)
	case jttDouble:
		return wrap(TypeDouble)
	case jttString:
		return wrap(TypeString)
	case jttBool:
		return wrap(TypeBoolean)
	case jttNull:
		return wrap(TypeNull)
	}
	return nil
}
// ensureColon returns an error naming key unless the parser state s shows
// that a colon was just read.
func ensureColon(s jsonParseState, key string) error {
	if s == jpsSawColon {
		return nil
	}
	return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
// invalidRequestError reports that the caller asked to read s (typically a
// type name) at a point where that read is not possible.
func invalidRequestError(s string) error {
	const format = "invalid request to read %s"
	return fmt.Errorf(format, s)
}
// invalidJSONError reports malformed input, naming the token or construct
// that was expected.
func invalidJSONError(expected string) error {
	const format = "invalid JSON input; expected %s"
	return fmt.Errorf(format, expected)
}
// invalidJSONErrorForType reports malformed input, naming what was expected
// and the BSON type that was being read.
func invalidJSONErrorForType(expected string, t Type) error {
	const format = "invalid JSON input; expected %s for %s"
	return fmt.Errorf(format, expected, t)
}
// unexpectedTokenError builds the error for a token that is not valid at its
// position. The message describes the token by kind and includes its value
// (where it has one) and its position. Token kinds not listed explicitly are
// formatted from their byte value.
func unexpectedTokenError(jt *jsonToken) error {
	var err error
	switch jt.t {
	case jttEOF:
		err = fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
	case jttNull:
		err = fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
	case jttBool:
		err = fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
	case jttString:
		err = fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
	case jttInt32, jttInt64, jttDouble:
		err = fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
	default:
		err = fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
	}
	return err
}
// nestingDepthError reports that the input nests deeper than allowed. depth
// is the nesting level reached and p is the position in the input.
func nestingDepthError(p, depth int) error {
	const format = "invalid JSON input; nesting too deep (%d levels) at position %d"
	return fmt.Errorf(format, depth, p)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
)
// ejvrState is one frame of the extJSONValueReader stack.
type ejvrState struct {
// mode is the kind of position this frame represents (top level, document,
// array, element, value, code-with-scope).
mode mode
// vType is the BSON type of the value at this frame; it is what Type reports.
vType Type
// depth is the parser depth captured when a document frame is pushed.
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
// p is the parser that supplies keys, types, and values.
p *extJSONParser
// stack holds one frame per nesting level that has been entered.
stack []ejvrState
// frame is the index of the current frame in stack.
frame int
}
// NewExtJSONValueReader returns a ValueReader that reads Extended JSON values
// from r. If canonicalOnly is true, reading values from the ValueReader returns
// an error if the Extended JSON was not marshaled in canonical mode.
//
// On failure the returned ValueReader is a true nil interface value.
func NewExtJSONValueReader(r io.Reader, canonicalOnly bool) (ValueReader, error) {
	ejvr, err := newExtJSONValueReader(r, canonicalOnly)
	if err != nil {
		// Return an untyped nil: passing the nil *extJSONValueReader through
		// would yield a non-nil interface wrapping a nil pointer.
		return nil, err
	}
	return ejvr, nil
}
// newExtJSONValueReader builds a concrete extJSONValueReader over r by
// delegating to reset.
func newExtJSONValueReader(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
	return new(extJSONValueReader).reset(r, canonicalOnly)
}
// reset creates a parser over r, peeks at the type of the top-level value,
// and returns a new reader positioned on it. The receiver itself is not
// modified. Any error from peeking the type is reported as ErrInvalidJSON.
func (ejvr *extJSONValueReader) reset(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
	parser := newExtJSONParser(r, canonicalOnly)
	topType, err := parser.peekType()
	if err != nil {
		return nil, ErrInvalidJSON
	}

	// The initial mode follows the shape of the top-level value.
	var topMode mode
	switch topType {
	case TypeArray:
		topMode = mArray
	case TypeEmbeddedDocument:
		topMode = mTopLevel
	default:
		topMode = mValue
	}

	frames := make([]ejvrState, 1, 5)
	frames[0] = ejvrState{mode: topMode, vType: topType}

	return &extJSONValueReader{p: parser, stack: frames}, nil
}
// advanceFrame moves to the next stack frame, growing the stack when needed,
// and clears that frame so no state from an earlier use leaks into it.
func (ejvr *extJSONValueReader) advanceFrame() {
	if next := ejvr.frame + 1; next >= len(ejvr.stack) {
		size := len(ejvr.stack)
		if size+1 >= cap(ejvr.stack) {
			// Out of capacity: move to a backing array of roughly twice the size.
			grown := make([]ejvrState, 2*cap(ejvr.stack)+1)
			copy(grown, ejvr.stack)
			ejvr.stack = grown
		}
		ejvr.stack = ejvr.stack[:size+1]
	}
	ejvr.frame++

	// Reset every field of the frame that is about to be used.
	ejvr.stack[ejvr.frame] = ejvrState{}
}
// pushDocument enters a new document frame and records the parser's current
// depth in it.
func (ejvr *extJSONValueReader) pushDocument() {
	ejvr.advanceFrame()
	top := &ejvr.stack[ejvr.frame]
	top.mode = mDocument
	top.depth = ejvr.p.depth
}
// pushCodeWithScope enters a new frame in code-with-scope mode.
func (ejvr *extJSONValueReader) pushCodeWithScope() {
	ejvr.advanceFrame()
	top := &ejvr.stack[ejvr.frame]
	top.mode = mCodeWithScope
}
// pushArray enters a new frame in array mode.
func (ejvr *extJSONValueReader) pushArray() {
	ejvr.advanceFrame()
	top := &ejvr.stack[ejvr.frame]
	top.mode = mArray
}
// push enters a new frame with mode m, positioned on a value of type t.
func (ejvr *extJSONValueReader) push(m mode, t Type) {
	ejvr.advanceFrame()
	top := &ejvr.stack[ejvr.frame]
	top.mode, top.vType = m, t
}
// pop leaves the current frame. Element and value frames go back one level.
// Document, array, and code-with-scope frames go back two, which also
// discards the element/value frame they were entered from. Frames in any
// other mode are left in place.
func (ejvr *extJSONValueReader) pop() {
	switch cur := ejvr.stack[ejvr.frame].mode; cur {
	case mDocument, mArray, mCodeWithScope:
		ejvr.frame -= 2
	case mElement, mValue:
		ejvr.frame--
	}
}
// skipObject advances the parser past the object or array currently being
// read, including anything nested inside it. It tracks nesting with a local
// counter that starts at 1 and stops when the counter reaches zero.
func (ejvr *extJSONValueReader) skipObject() {
// read entire object until depth returns to 0 (last ending } or ] seen)
depth := 1
for depth > 0 {
ejvr.p.advanceState()
// If object is empty, raise depth and continue. When emptyObject is true, the
// parser has already read both the opening and closing brackets of an empty
// object ("{}"), so the next valid token will be part of the parent document,
// not part of the nested document.
//
// If there is a comma, there are remaining fields, emptyObject must be set back
// to false, and comma must be skipped with advanceState().
if ejvr.p.emptyObject {
if ejvr.p.s == jpsSawComma {
ejvr.p.emptyObject = false
ejvr.p.advanceState()
}
depth--
continue
}
// Opening tokens go one level deeper, closing tokens come back up; all
// other tokens are simply consumed.
switch ejvr.p.s {
case jpsSawBeginObject, jpsSawBeginArray:
depth++
case jpsSawEndObject, jpsSawEndArray:
depth--
}
}
}
// invalidTransitionErr builds the TransitionError for a read attempted from
// the wrong mode. name is the calling method, destination the mode it wanted
// to enter, and modes the modes it may be called from. The parent mode is
// filled in only when the current frame is not the bottom of the stack.
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
	var parent mode
	if ejvr.frame != 0 {
		parent = ejvr.stack[ejvr.frame-1].mode
	}
	return TransitionError{
		name:        name,
		parent:      parent,
		current:     ejvr.stack[ejvr.frame].mode,
		destination: destination,
		modes:       modes,
		action:      "read",
	}
}
// typeError reports an attempt to read type t while the reader is positioned
// on a value of a different type.
func (ejvr *extJSONValueReader) typeError(t Type) error {
	actual := ejvr.stack[ejvr.frame].vType
	return fmt.Errorf("positioned on %s, but attempted to read %s", actual, t)
}
// ensureElementValue checks that the reader is positioned on an element or
// value of type t. A wrong mode yields a transition error that lists element,
// value, and any addModes as the allowed modes. A wrong type yields a type
// error.
func (ejvr *extJSONValueReader) ensureElementValue(t Type, destination mode, callerName string, addModes ...mode) error {
	cur := ejvr.stack[ejvr.frame]
	if cur.mode != mElement && cur.mode != mValue {
		allowed := append([]mode{mElement, mValue}, addModes...)
		return ejvr.invalidTransitionErr(destination, callerName, allowed)
	}
	if cur.vType != t {
		return ejvr.typeError(t)
	}
	return nil
}
// Type returns the BSON type recorded for the current frame.
func (ejvr *extJSONValueReader) Type() Type {
	cur := &ejvr.stack[ejvr.frame]
	return cur.vType
}
// Skip discards the value the reader is positioned on. It may only be called
// from an element or value frame. Once that check passes the frame is popped
// on return, whether or not skipping succeeded.
func (ejvr *extJSONValueReader) Skip() error {
	cur := ejvr.stack[ejvr.frame]
	if cur.mode != mElement && cur.mode != mValue {
		return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
	}
	defer ejvr.pop()

	switch cur.vType {
	case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
		// Containers are skipped token by token, including nested content.
		ejvr.skipObject()
		return nil
	}

	// Everything else is read as a single value and thrown away.
	_, err := ejvr.p.readValue(cur.vType)
	return err
}
// ReadArray starts reading an array. If the reader is already in array mode
// it is returned as is. From the top level a new array frame is pushed
// without further checks. From any other mode the reader must be positioned
// on an element or value of type array.
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
	cur := ejvr.stack[ejvr.frame].mode
	if cur == mArray {
		return ejvr, nil
	}
	if cur != mTopLevel {
		err := ejvr.ensureElementValue(TypeArray, mArray, "ReadArray", mTopLevel, mArray)
		if err != nil {
			return nil, err
		}
	}
	ejvr.pushArray()
	return ejvr, nil
}
// ReadBinary reads the current value as binary data and returns its bytes and
// subtype. The frame is popped once the value has been read from the parser,
// even if decoding its contents fails.
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
	if err = ejvr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
		return nil, 0, err
	}
	raw, err := ejvr.p.readValue(TypeBinary)
	if err != nil {
		return nil, 0, err
	}
	data, subtype, parseErr := raw.parseBinary()
	ejvr.pop()
	return data, subtype, parseErr
}
// ReadBoolean reads the current value as a boolean. The frame is popped only
// when the value read from the parser is actually a boolean.
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
	err := ejvr.ensureElementValue(TypeBoolean, 0, "ReadBoolean")
	if err != nil {
		return false, err
	}
	raw, err := ejvr.p.readValue(TypeBoolean)
	switch {
	case err != nil:
		return false, err
	case raw.t != TypeBoolean:
		return false, fmt.Errorf("expected type bool, but got type %s", raw.t)
	}
	ejvr.pop()
	return raw.v.(bool), nil
}
// ReadDocument starts reading a document. At the top level the reader itself
// is returned. From an element or value frame the value must be an embedded
// document, and a document frame is pushed. Any other mode is a transition
// error.
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
	cur := ejvr.stack[ejvr.frame]
	switch {
	case cur.mode == mTopLevel:
		return ejvr, nil
	case cur.mode != mElement && cur.mode != mValue:
		return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
	case cur.vType != TypeEmbeddedDocument:
		return nil, ejvr.typeError(TypeEmbeddedDocument)
	}
	ejvr.pushDocument()
	return ejvr, nil
}
// ReadCodeWithScope reads the code string of a code-with-scope value and
// returns the reader positioned to read the scope document. The
// code-with-scope frame is pushed once the value has been read from the
// parser, even if the code string fails to decode.
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
	if err = ejvr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
		return "", nil, err
	}
	raw, err := ejvr.p.readValue(TypeCodeWithScope)
	if err != nil {
		return "", nil, err
	}
	js, parseErr := raw.parseJavascript()
	ejvr.pushCodeWithScope()
	return js, ejvr, parseErr
}
// ReadDBPointer reads the current value as a DBPointer and returns its
// namespace and ObjectID. The frame is popped once the value has been read
// from the parser, even if decoding its contents fails.
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
	if err = ejvr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
		return "", NilObjectID, err
	}
	raw, err := ejvr.p.readValue(TypeDBPointer)
	if err != nil {
		return "", NilObjectID, err
	}
	namespace, id, parseErr := raw.parseDBPointer()
	ejvr.pop()
	return namespace, id, parseErr
}
// ReadDateTime reads the current value as a datetime and returns it as an
// int64. The frame is popped once the value has been read from the parser,
// even if decoding it fails.
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
	err := ejvr.ensureElementValue(TypeDateTime, 0, "ReadDateTime")
	if err != nil {
		return 0, err
	}
	raw, err := ejvr.p.readValue(TypeDateTime)
	if err != nil {
		return 0, err
	}
	millis, err := raw.parseDateTime()
	ejvr.pop()
	return millis, err
}
// ReadDecimal128 reads the current value as a Decimal128. The frame is popped
// once the value has been read from the parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadDecimal128() (Decimal128, error) {
	err := ejvr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128")
	if err != nil {
		return Decimal128{}, err
	}
	raw, err := ejvr.p.readValue(TypeDecimal128)
	if err != nil {
		return Decimal128{}, err
	}
	dec, err := raw.parseDecimal128()
	ejvr.pop()
	return dec, err
}
// ReadDouble reads the current value as a float64. The frame is popped once
// the value has been read from the parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
	err := ejvr.ensureElementValue(TypeDouble, 0, "ReadDouble")
	if err != nil {
		return 0, err
	}
	raw, err := ejvr.p.readValue(TypeDouble)
	if err != nil {
		return 0, err
	}
	f, err := raw.parseDouble()
	ejvr.pop()
	return f, err
}
// ReadInt32 reads the current value as an int32. The frame is popped once the
// value has been read from the parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
	err := ejvr.ensureElementValue(TypeInt32, 0, "ReadInt32")
	if err != nil {
		return 0, err
	}
	raw, err := ejvr.p.readValue(TypeInt32)
	if err != nil {
		return 0, err
	}
	n, err := raw.parseInt32()
	ejvr.pop()
	return n, err
}
// ReadInt64 reads the current value as an int64. The frame is popped once the
// value has been read from the parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
	err := ejvr.ensureElementValue(TypeInt64, 0, "ReadInt64")
	if err != nil {
		return 0, err
	}
	raw, err := ejvr.p.readValue(TypeInt64)
	if err != nil {
		return 0, err
	}
	n, err := raw.parseInt64()
	ejvr.pop()
	return n, err
}
// ReadJavascript reads the current value as JavaScript code and returns the
// code string. The frame is popped once the value has been read from the
// parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
	if err = ejvr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
		return "", err
	}
	raw, err := ejvr.p.readValue(TypeJavaScript)
	if err != nil {
		return "", err
	}
	js, parseErr := raw.parseJavascript()
	ejvr.pop()
	return js, parseErr
}
// ReadMaxKey reads the current value as a max key. The frame is popped once
// the value has been read from the parser, even if validating it fails.
func (ejvr *extJSONValueReader) ReadMaxKey() error {
	err := ejvr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey")
	if err != nil {
		return err
	}
	raw, err := ejvr.p.readValue(TypeMaxKey)
	if err != nil {
		return err
	}
	parseErr := raw.parseMinMaxKey("max")
	ejvr.pop()
	return parseErr
}
// ReadMinKey reads the current value as a min key. The frame is popped once
// the value has been read from the parser, even if validating it fails.
func (ejvr *extJSONValueReader) ReadMinKey() error {
	err := ejvr.ensureElementValue(TypeMinKey, 0, "ReadMinKey")
	if err != nil {
		return err
	}
	raw, err := ejvr.p.readValue(TypeMinKey)
	if err != nil {
		return err
	}
	parseErr := raw.parseMinMaxKey("min")
	ejvr.pop()
	return parseErr
}
// ReadNull reads the current value as null. The frame is popped only when the
// value read from the parser is actually null.
func (ejvr *extJSONValueReader) ReadNull() error {
	err := ejvr.ensureElementValue(TypeNull, 0, "ReadNull")
	if err != nil {
		return err
	}
	raw, err := ejvr.p.readValue(TypeNull)
	switch {
	case err != nil:
		return err
	case raw.t != TypeNull:
		return fmt.Errorf("expected type null but got type %s", raw.t)
	}
	ejvr.pop()
	return nil
}
// ReadObjectID reads the current value as an ObjectID. The frame is popped
// once the value has been read from the parser, even if decoding it fails.
func (ejvr *extJSONValueReader) ReadObjectID() (ObjectID, error) {
	err := ejvr.ensureElementValue(TypeObjectID, 0, "ReadObjectID")
	if err != nil {
		return ObjectID{}, err
	}
	raw, err := ejvr.p.readValue(TypeObjectID)
	if err != nil {
		return ObjectID{}, err
	}
	id, err := raw.parseObjectID()
	ejvr.pop()
	return id, err
}
// ReadRegex reads the current value as a regular expression and returns its
// pattern and options. The frame is popped once the value has been read from
// the parser, even if decoding its contents fails.
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
	if err = ejvr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
		return "", "", err
	}
	raw, err := ejvr.p.readValue(TypeRegex)
	if err != nil {
		return "", "", err
	}
	pat, opts, parseErr := raw.parseRegex()
	ejvr.pop()
	return pat, opts, parseErr
}
// ReadString reads the current value as a string. The frame is popped only
// when the value read from the parser is actually a string.
func (ejvr *extJSONValueReader) ReadString() (string, error) {
	err := ejvr.ensureElementValue(TypeString, 0, "ReadString")
	if err != nil {
		return "", err
	}
	raw, err := ejvr.p.readValue(TypeString)
	switch {
	case err != nil:
		return "", err
	case raw.t != TypeString:
		return "", fmt.Errorf("expected type string but got type %s", raw.t)
	}
	ejvr.pop()
	return raw.v.(string), nil
}
// ReadSymbol reads the current value as a symbol and returns its string. The
// frame is popped once the value has been read from the parser, even if
// decoding it fails.
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
	if err = ejvr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
		return "", err
	}
	raw, err := ejvr.p.readValue(TypeSymbol)
	if err != nil {
		return "", err
	}
	sym, parseErr := raw.parseSymbol()
	ejvr.pop()
	return sym, parseErr
}
// ReadTimestamp reads the current value as a timestamp and returns its t and
// i components. The frame is popped once the value has been read from the
// parser, even if decoding its contents fails.
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
	if err = ejvr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
		return 0, 0, err
	}
	raw, err := ejvr.p.readValue(TypeTimestamp)
	if err != nil {
		return 0, 0, err
	}
	seconds, increment, parseErr := raw.parseTimestamp()
	ejvr.pop()
	return seconds, increment, parseErr
}
// ReadUndefined reads the current value as undefined. The frame is popped
// once the value has been read from the parser, even if validating it fails.
func (ejvr *extJSONValueReader) ReadUndefined() error {
	err := ejvr.ensureElementValue(TypeUndefined, 0, "ReadUndefined")
	if err != nil {
		return err
	}
	raw, err := ejvr.p.readValue(TypeUndefined)
	if err != nil {
		return err
	}
	parseErr := raw.parseUndefined()
	ejvr.pop()
	return parseErr
}
// ReadElement reads the next key of the current document and returns it with
// a reader positioned on its value. It may be called only from top-level,
// document, or code-with-scope mode.
//
// When the parser reports ErrEOD the frame is popped and ErrEOD is returned.
// In code-with-scope mode the parser's peekType is called first, and a
// failure there is returned instead without popping.
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
	cur := ejvr.stack[ejvr.frame].mode
	if cur != mTopLevel && cur != mDocument && cur != mCodeWithScope {
		return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
	}

	key, typ, err := ejvr.p.readKey()
	if err == nil {
		ejvr.push(mElement, typ)
		return key, ejvr, nil
	}

	if errors.Is(err, ErrEOD) {
		if cur == mCodeWithScope {
			if _, peekErr := ejvr.p.peekType(); peekErr != nil {
				return "", nil, peekErr
			}
		}
		ejvr.pop()
	}
	return "", nil, err
}
// ReadValue peeks at the type of the next array item and returns a reader
// positioned on it. It may be called only from array mode. When the parser
// reports ErrEOA the array frame is popped and ErrEOA is returned.
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
	if ejvr.stack[ejvr.frame].mode != mArray {
		return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
	}

	typ, err := ejvr.p.peekType()
	if err == nil {
		ejvr.push(mValue, typ)
		return ejvr, nil
	}

	if errors.Is(err, ErrEOA) {
		ejvr.pop()
	}
	return nil, err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
)
// wrapperKeyBSONType maps an extended JSON wrapper key (such as "$oid" or
// "$numberLong") to the BSON type it denotes. Any other key maps to
// TypeEmbeddedDocument.
func wrapperKeyBSONType(key string) Type {
	var t Type
	switch key {
	case "$numberInt":
		t = TypeInt32
	case "$numberLong":
		t = TypeInt64
	case "$oid":
		t = TypeObjectID
	case "$symbol":
		t = TypeSymbol
	case "$numberDouble":
		t = TypeDouble
	case "$numberDecimal":
		t = TypeDecimal128
	case "$binary":
		t = TypeBinary
	case "$code":
		t = TypeJavaScript
	case "$scope":
		t = TypeCodeWithScope
	case "$timestamp":
		t = TypeTimestamp
	case "$regularExpression":
		t = TypeRegex
	case "$dbPointer":
		t = TypeDBPointer
	case "$date":
		t = TypeDateTime
	case "$minKey":
		t = TypeMinKey
	case "$maxKey":
		t = TypeMaxKey
	case "$undefined":
		t = TypeUndefined
	default:
		t = TypeEmbeddedDocument
	}
	return t
}
// parseBinary decodes the object form of $binary. The object must contain
// exactly one "base64" string (standard base64) and one "subType" string (a
// hex byte). Duplicate, missing, unknown, or mistyped fields are errors.
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
	}

	obj := ejv.v.(*extJSONObject)
	var haveData, haveSubType bool

	for idx, name := range obj.keys {
		field := obj.values[idx]

		switch name {
		case "base64":
			if haveData {
				return nil, 0, errors.New("duplicate base64 key in $binary")
			}
			if field.t != TypeString {
				return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", field.t)
			}
			decoded, decErr := base64.StdEncoding.DecodeString(field.v.(string))
			if decErr != nil {
				return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", field.v.(string))
			}
			b, haveData = decoded, true

		case "subType":
			if haveSubType {
				return nil, 0, errors.New("duplicate subType key in $binary")
			}
			if field.t != TypeString {
				return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", field.t)
			}
			// The subtype is a hexadecimal string that must fit in 8 bits.
			code, convErr := strconv.ParseUint(field.v.(string), 16, 8)
			if convErr != nil {
				return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", field.v.(string), convErr)
			}
			subType, haveSubType = byte(code), true

		default:
			return nil, 0, fmt.Errorf("invalid key in $binary object: %s", name)
		}
	}

	switch {
	case !haveData:
		return nil, 0, errors.New("missing base64 field in $binary object")
	case !haveSubType:
		return nil, 0, errors.New("missing subType field in $binary object")
	}
	return b, subType, nil
}
// parseDBPointer decodes the object form of $dbPointer. The object must
// contain exactly one "$ref" string (the namespace) and one "$id" string (a
// hex ObjectID). Duplicate, missing, unknown, or mistyped fields are errors.
func (ejv *extJSONValue) parseDBPointer() (ns string, oid ObjectID, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return "", NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
	}

	obj := ejv.v.(*extJSONObject)
	var haveRef, haveID bool

	for idx, name := range obj.keys {
		field := obj.values[idx]

		switch name {
		case "$ref":
			if haveRef {
				return "", NilObjectID, errors.New("duplicate $ref key in $dbPointer")
			}
			if field.t != TypeString {
				return "", NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", field.t)
			}
			ns, haveRef = field.v.(string), true

		case "$id":
			if haveID {
				return "", NilObjectID, errors.New("duplicate $id key in $dbPointer")
			}
			if field.t != TypeString {
				return "", NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", field.t)
			}
			if oid, err = ObjectIDFromHex(field.v.(string)); err != nil {
				return "", NilObjectID, err
			}
			haveID = true

		default:
			return "", NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", name)
		}
	}

	// The missing-field errors return whatever oid holds at this point.
	switch {
	case !haveRef:
		return "", oid, errors.New("missing $ref field in $dbPointer object")
	case !haveID:
		return "", oid, errors.New("missing $id field in $dbPointer object")
	}
	return ns, oid, nil
}
const (
// rfc3339Milli is a time layout with optional fractional seconds of up to
// three digits and a zone written as "Z" or with a colon (e.g. "+07:00").
rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)
var (
// timeFormats lists the layouts tried, in order, when parsing a $date
// string. The second layout differs from rfc3339Milli only in writing the
// zone offset without a colon (e.g. "+0700").
timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
)
// parseDateTime decodes a $date value to an int64. Integer values are
// returned as they are, strings are parsed as timestamps, and objects are
// read as a $numberLong wrapper. Any other type is an error.
func (ejv *extJSONValue) parseDateTime() (int64, error) {
	switch raw := ejv.v; ejv.t {
	case TypeInt32:
		return int64(raw.(int32)), nil
	case TypeInt64:
		return raw.(int64), nil
	case TypeString:
		return parseDatetimeString(raw.(string))
	case TypeEmbeddedDocument:
		return parseDatetimeObject(raw.(*extJSONObject))
	}
	return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
// parseDatetimeString parses data using each layout in timeFormats in turn
// and converts the first successful result via NewDateTimeFromTime. It
// returns an error when the last layout tried fails.
func parseDatetimeString(data string) (int64, error) {
	var (
		parsed  time.Time
		lastErr error
	)
	for _, layout := range timeFormats {
		// Stop at the first layout that accepts the input.
		if parsed, lastErr = time.Parse(layout, data); lastErr == nil {
			break
		}
	}
	if lastErr != nil {
		return 0, fmt.Errorf("invalid $date value string: %s", data)
	}
	return int64(NewDateTimeFromTime(parsed)), nil
}
// parseDatetimeObject decodes the object form of $date, which must contain
// exactly one "$numberLong" key whose value is a string holding an int64.
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
	seen := false
	for idx, name := range data.keys {
		field := data.values[idx]

		if name != "$numberLong" {
			return 0, fmt.Errorf("invalid key in $date object: %s", name)
		}
		if seen {
			return 0, errors.New("duplicate $numberLong key in $date")
		}
		if field.t != TypeString {
			return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", field.t)
		}
		if d, err = field.parseInt64(); err != nil {
			return 0, err
		}
		seen = true
	}

	if !seen {
		return 0, errors.New("missing $numberLong field in $date object")
	}
	return d, nil
}
// parseDecimal128 decodes a $numberDecimal value, which must be a string
// accepted by ParseDecimal128. It returns an error for a non-string value or
// a string that does not parse.
func (ejv *extJSONValue) parseDecimal128() (Decimal128, error) {
	if ejv.t != TypeString {
		return Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
	}
	d, err := ParseDecimal128(ejv.v.(string))
	if err != nil {
		// The message previously began with a stray "$" ("$invalid ...").
		return Decimal128{}, fmt.Errorf("invalid $numberDecimal string: %s", ejv.v.(string))
	}
	return d, nil
}
// parseDouble decodes a $numberDouble value. A double is returned as is. A
// string is either one of the special names "Infinity", "-Infinity", or
// "NaN", or a number parsed with strconv.ParseFloat. Any other type is an
// error.
func (ejv *extJSONValue) parseDouble() (float64, error) {
	if ejv.t == TypeDouble {
		return ejv.v.(float64), nil
	}
	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
	}

	text := ejv.v.(string)
	switch text {
	case "NaN":
		return math.NaN(), nil
	case "Infinity":
		return math.Inf(1), nil
	case "-Infinity":
		return math.Inf(-1), nil
	}

	parsed, err := strconv.ParseFloat(text, 64)
	if err != nil {
		// Return zero rather than whatever ParseFloat produced alongside the error.
		return 0, err
	}
	return parsed, nil
}
// parseInt32 decodes a $numberInt value. An int32 is returned as is. A string
// is parsed as a base-10 integer and must fit in 32 bits. Any other type is
// an error.
func (ejv *extJSONValue) parseInt32() (int32, error) {
	if ejv.t == TypeInt32 {
		return ejv.v.(int32), nil
	}
	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
	}

	// Parse at 64 bits so an out-of-range value gets the dedicated message below.
	wide, err := strconv.ParseInt(ejv.v.(string), 10, 64)
	switch {
	case err != nil:
		return 0, err
	case wide < math.MinInt32 || wide > math.MaxInt32:
		return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", wide)
	}
	return int32(wide), nil
}
// parseInt64 decodes a $numberLong value. An int64 is returned as is. A
// string is parsed as a base-10 64-bit integer. Any other type is an error.
func (ejv *extJSONValue) parseInt64() (int64, error) {
	if ejv.t == TypeInt64 {
		return ejv.v.(int64), nil
	}
	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
	}

	parsed, err := strconv.ParseInt(ejv.v.(string), 10, 64)
	if err != nil {
		// Return zero rather than whatever ParseInt produced alongside the error.
		return 0, err
	}
	return parsed, nil
}
// parseJavascript returns the string held by a $code value, or an error if
// the value is not a string.
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
	if ejv.t == TypeString {
		return ejv.v.(string), nil
	}
	return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
// parseMinMaxKey validates a $minKey or $maxKey value, which must be the
// int32 value 1. minmax ("min" or "max") only selects the key name used in
// error messages.
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
	if ejv.t != TypeInt32 {
		return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
	}
	if got := ejv.v.(int32); got != 1 {
		return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, got)
	}
	return nil
}
// parseObjectID decodes a $oid value, which must be a string, via
// ObjectIDFromHex.
func (ejv *extJSONValue) parseObjectID() (ObjectID, error) {
	if ejv.t == TypeString {
		return ObjectIDFromHex(ejv.v.(string))
	}
	return NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
// parseRegex decodes the object form of $regularExpression. The object must
// contain exactly one "pattern" string and one "options" string. Duplicate,
// missing, unknown, or mistyped fields are errors.
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
	}

	obj := ejv.v.(*extJSONObject)
	var havePattern, haveOptions bool

	for idx, name := range obj.keys {
		field := obj.values[idx]

		switch name {
		case "pattern":
			if havePattern {
				return "", "", errors.New("duplicate pattern key in $regularExpression")
			}
			if field.t != TypeString {
				return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", field.t)
			}
			pattern, havePattern = field.v.(string), true

		case "options":
			if haveOptions {
				return "", "", errors.New("duplicate options key in $regularExpression")
			}
			if field.t != TypeString {
				return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", field.t)
			}
			options, haveOptions = field.v.(string), true

		default:
			return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", name)
		}
	}

	switch {
	case !havePattern:
		return "", "", errors.New("missing pattern field in $regularExpression object")
	case !haveOptions:
		return "", "", errors.New("missing options field in $regularExpression object")
	}
	return pattern, options, nil
}
// parseSymbol returns the string held by a $symbol value, or an error if the
// value is not a string.
func (ejv *extJSONValue) parseSymbol() (string, error) {
	if ejv.t == TypeString {
		return ejv.v.(string), nil
	}
	return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
// parseTimestamp decodes the object form of $timestamp. The object must
// contain exactly one "t" and one "i" field, each an int32 or int64 number
// within the uint32 range. Duplicate, missing, unknown, or out-of-range
// fields are errors.
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
	}

	// readField converts one field to uint32, rejecting a repeated key (seen)
	// and numbers that do not fit.
	readField := func(name string, field *extJSONValue, seen bool) (uint32, error) {
		if seen {
			return 0, fmt.Errorf("duplicate %s key in $timestamp", name)
		}
		switch field.t {
		case TypeInt32:
			n := field.v.(int32)
			if n < 0 {
				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", name, n)
			}
			return uint32(n), nil
		case TypeInt64:
			n := field.v.(int64)
			if n < 0 || n > int64(math.MaxUint32) {
				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", name, n)
			}
			return uint32(n), nil
		}
		return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", name, field.t)
	}

	obj := ejv.v.(*extJSONObject)
	var haveT, haveI bool

	for idx, name := range obj.keys {
		field := obj.values[idx]

		switch name {
		case "t":
			t, err = readField(name, field, haveT)
			if err != nil {
				return 0, 0, err
			}
			haveT = true
		case "i":
			i, err = readField(name, field, haveI)
			if err != nil {
				return 0, 0, err
			}
			haveI = true
		default:
			return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", name)
		}
	}

	switch {
	case !haveT:
		return 0, 0, errors.New("missing t field in $timestamp object")
	case !haveI:
		return 0, 0, errors.New("missing i field in $timestamp object")
	}
	return t, i, nil
}
// parseUndefined validates a $undefined value, which must be the boolean
// true. It returns an error for a non-boolean value or for false.
func (ejv *extJSONValue) parseUndefined() error {
	if ejv.t != TypeBoolean {
		return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
	}
	if !ejv.v.(bool) {
		// The message previously misspelled "value" as "balue".
		return fmt.Errorf("$undefined value boolean should be true, but instead is %v", ejv.v.(bool))
	}
	return nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"math"
"sort"
"strconv"
"strings"
"time"
"unicode/utf8"
)
// ejvwState is one frame of the extJSONValueWriter stack.
type ejvwState struct {
// mode is the kind of position this frame represents (top level, document,
// array, element, value, code-with-scope).
mode mode
}
// extJSONValueWriter writes values as Extended JSON.
type extJSONValueWriter struct {
// w is the destination writer; it is nil for writers built from a slice
// or after reset.
w io.Writer
// buf accumulates the encoded output; the Write* methods append to it.
buf []byte
// stack holds one frame per nesting level that has been entered.
stack []ejvwState
// frame is the index of the current frame in stack.
frame int64
// canonical is supplied by the constructors.
// NOTE(review): presumably selects canonical over relaxed output — confirm
// against the Write* methods that read it.
canonical bool
// escapeHTML is passed to writeStringWithEscapes when writing strings.
escapeHTML bool
// newlines is set by NewExtJSONValueWriter.
// NOTE(review): presumably controls a trailing newline after each
// document — confirm against the code that flushes buf.
newlines bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) ValueWriter {
	// Writers created here always have newlines enabled. They are expected to
	// be used with an Encoder, which should add a newline after each encoded
	// Extended JSON document.
	const newlines = true
	return newExtJSONWriter(w, canonical, escapeHTML, newlines)
}
// newExtJSONWriter builds a writer that targets w, starting with an empty
// buffer and a single top-level frame.
func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter {
	frames := make([]ejvwState, 1, 5)
	frames[0].mode = mTopLevel

	ejvw := &extJSONValueWriter{
		w:     w,
		buf:   []byte{},
		stack: frames,
	}
	ejvw.canonical = canonical
	ejvw.escapeHTML = escapeHTML
	ejvw.newlines = newlines
	return ejvw
}
// newExtJSONWriterFromSlice builds a writer that uses buf as its output
// buffer and has no io.Writer. Newlines are left disabled.
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
	frames := make([]ejvwState, 1, 5)
	frames[0].mode = mTopLevel

	ejvw := &extJSONValueWriter{buf: buf, stack: frames}
	ejvw.canonical = canonical
	ejvw.escapeHTML = escapeHTML
	return ejvw
}
// reset returns the writer to a single top-level frame, replaces the output
// buffer with buf, clears the io.Writer, and applies the new canonical and
// escapeHTML settings. The newlines setting is left as it was.
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
	frames := ejvw.stack
	if frames == nil {
		frames = make([]ejvwState, 1, 5)
	}
	// Keep the existing backing array; only the first frame stays in use.
	frames = frames[:1]
	frames[0] = ejvwState{mode: mTopLevel}
	ejvw.stack = frames

	ejvw.canonical, ejvw.escapeHTML = canonical, escapeHTML
	ejvw.frame = 0
	ejvw.buf, ejvw.w = buf, nil
}
// advanceFrame moves to the next stack frame, growing the stack when needed.
func (ejvw *extJSONValueWriter) advanceFrame() {
	if next := ejvw.frame + 1; next >= int64(len(ejvw.stack)) {
		size := len(ejvw.stack)
		if size+1 >= cap(ejvw.stack) {
			// Out of capacity: move to a backing array of roughly twice the size.
			grown := make([]ejvwState, 2*cap(ejvw.stack)+1)
			copy(grown, ejvw.stack)
			ejvw.stack = grown
		}
		ejvw.stack = ejvw.stack[:size+1]
	}
	ejvw.frame++
}
// push enters a new frame with mode m.
func (ejvw *extJSONValueWriter) push(m mode) {
	ejvw.advanceFrame()
	top := &ejvw.stack[ejvw.frame]
	top.mode = m
}
// pop leaves the current frame. Element and value frames go back one level.
// Document, array, and code-with-scope frames go back two, which also
// discards the element frame they were entered from. Frames in any other
// mode are left in place.
func (ejvw *extJSONValueWriter) pop() {
	switch cur := ejvw.stack[ejvw.frame].mode; cur {
	case mDocument, mArray, mCodeWithScope:
		ejvw.frame -= 2
	case mElement, mValue:
		ejvw.frame--
	}
}
// invalidTransitionErr builds the TransitionError for a write attempted from
// the wrong mode. name is the calling method, destination the mode it wanted
// to enter, and modes the modes it may be called from. The parent mode is
// filled in only when the current frame is not the bottom of the stack.
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
	var parent mode
	if ejvw.frame != 0 {
		parent = ejvw.stack[ejvw.frame-1].mode
	}
	return TransitionError{
		name:        name,
		parent:      parent,
		current:     ejvw.stack[ejvw.frame].mode,
		destination: destination,
		modes:       modes,
		action:      "write",
	}
}
// ensureElementValue checks that the writer is positioned on an element or
// value. Otherwise it returns a transition error that lists element, value,
// and any addmodes as the allowed modes.
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
	cur := ejvw.stack[ejvw.frame].mode
	if cur == mElement || cur == mValue {
		return nil
	}
	allowed := append([]mode{mElement, mValue}, addmodes...)
	return ejvw.invalidTransitionErr(destination, callerName, allowed)
}
// writeExtendedSingleValue appends a one-field wrapper object of the form {"$key":value}
// to the output buffer. When quotes is true the value is surrounded by double quotes.
// Neither key nor value is escaped here.
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
	var rendered string
	switch quotes {
	case true:
		rendered = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
	default:
		rendered = fmt.Sprintf(`{"$%s":%s}`, key, value)
	}
	ejvw.buf = append(ejvw.buf, rendered...)
}
// WriteArray begins a JSON array as the value of the current element. It emits the
// opening bracket, enters array mode and returns the writer itself as the ArrayWriter.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
	err := ejvw.ensureElementValue(mArray, "WriteArray")
	if err != nil {
		return nil, err
	}

	ejvw.push(mArray)
	ejvw.buf = append(ejvw.buf, '[')
	return ejvw, nil
}
// WriteBinary writes b as a binary value using subtype 0x00. It is shorthand for
// WriteBinaryWithSubtype with that subtype.
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
	const defaultSubtype byte = 0x00
	return ejvw.WriteBinaryWithSubtype(b, defaultSubtype)
}
// WriteBinaryWithSubtype writes b as {"$binary":{"base64":...,"subType":...}} followed
// by a trailing comma. The payload is standard base64 and the subtype is rendered as
// two lowercase hex digits. An error is returned when the writer is not positioned on
// an element or value.
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
		return err
	}

	encoded := base64.StdEncoding.EncodeToString(b)

	var out bytes.Buffer
	out.WriteString(`{"$binary":{"base64":"`)
	out.WriteString(encoded)
	fmt.Fprintf(&out, `","subType":"%02x"}},`, btype)

	ejvw.buf = append(ejvw.buf, out.Bytes()...)
	ejvw.pop()
	return nil
}
// WriteBoolean writes b as the JSON literal true or false followed by a trailing comma.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
	err := ejvw.ensureElementValue(mode(0), "WriteBoolean")
	if err != nil {
		return err
	}

	ejvw.buf = strconv.AppendBool(ejvw.buf, b)
	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteCodeWithScope begins a code-with-scope value. It emits {"$code":<escaped code>,
// "$scope":{ (without the space), enters code-with-scope mode and returns the writer
// itself as the DocumentWriter for the scope document. An error is returned when the
// writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
	if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
		return nil, err
	}

	var escaped bytes.Buffer
	writeStringWithEscapes(code, &escaped, ejvw.escapeHTML)

	ejvw.buf = append(ejvw.buf, `{"$code":`...)
	ejvw.buf = append(ejvw.buf, escaped.Bytes()...)
	ejvw.buf = append(ejvw.buf, `,"$scope":{`...)

	ejvw.push(mCodeWithScope)
	return ejvw, nil
}
// WriteDBPointer writes a DBPointer value as
// {"$dbPointer":{"$ref":<ns>,"$id":{"$oid":<hex>}}} followed by a trailing comma.
//
// ns is emitted through writeStringWithEscapes so that quotes, backslashes and control
// characters in the namespace cannot break out of the JSON string. The ObjectID is
// written as its hex form. An error is returned when the writer is not positioned on
// an element or value.
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid ObjectID) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
		return err
	}

	var buf bytes.Buffer
	buf.WriteString(`{"$dbPointer":{"$ref":`)
	// writeStringWithEscapes supplies the surrounding double quotes itself.
	writeStringWithEscapes(ns, &buf, ejvw.escapeHTML)
	buf.WriteString(`,"$id":{"$oid":"`)
	buf.WriteString(oid.Hex())
	buf.WriteString(`"}}},`)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
	ejvw.pop()
	return nil
}
// WriteDateTime writes dt, a count of milliseconds relative to the Unix epoch, as a
// {"$date":...} value followed by a trailing comma. In canonical mode, or when the UTC
// year falls outside 1970..9999, the payload is {"$numberLong":"<dt>"}; otherwise it is
// the timestamp formatted with the rfc3339Milli layout. An error is returned when the
// writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
		return err
	}

	const (
		millisPerSecond = 1000
		nanosPerMilli   = 1000000
	)
	ts := time.Unix(dt/millisPerSecond, dt%millisPerSecond*nanosPerMilli).UTC()
	year := ts.Year()

	if !ejvw.canonical && year >= 1970 && year <= 9999 {
		ejvw.writeExtendedSingleValue("date", ts.Format(rfc3339Milli), true)
	} else {
		wrapped := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
		ejvw.writeExtendedSingleValue("date", wrapped, false)
	}

	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteDecimal128 writes d as {"$numberDecimal":"<d.String()>"} followed by a trailing
// comma. An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteDecimal128(d Decimal128) error {
	err := ejvw.ensureElementValue(mode(0), "WriteDecimal128")
	if err != nil {
		return err
	}

	text := d.String()
	ejvw.writeExtendedSingleValue("numberDecimal", text, true)
	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteDocument begins a JSON object and returns the writer itself as the
// DocumentWriter. At the top level only the opening brace is emitted and the mode is
// left unchanged. Anywhere else the writer must be positioned on an element or value,
// and a new document frame is pushed after the brace is written.
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
	atTopLevel := ejvw.stack[ejvw.frame].mode == mTopLevel

	if !atTopLevel {
		if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
			return nil, err
		}
	}

	ejvw.buf = append(ejvw.buf, '{')
	if !atTopLevel {
		ejvw.push(mDocument)
	}
	return ejvw, nil
}
// WriteDouble writes f followed by a trailing comma. In canonical mode every value is
// wrapped as {"$numberDouble":"<text>"}. Otherwise only Infinity, -Infinity and NaN are
// wrapped, and all other values are written as bare numbers using formatDouble's text.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
		return err
	}

	text := formatDouble(f)
	switch {
	case ejvw.canonical:
		ejvw.writeExtendedSingleValue("numberDouble", text, true)
	case text == "Infinity" || text == "-Infinity" || text == "NaN":
		// These have no bare JSON representation, so they are always wrapped.
		ejvw.buf = append(ejvw.buf, fmt.Sprintf(`{"$numberDouble":"%s"}`, text)...)
	default:
		ejvw.buf = append(ejvw.buf, text...)
	}

	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteInt32 writes i followed by a trailing comma. In canonical mode the value is
// wrapped as {"$numberInt":"<i>"}; otherwise it is written as a bare decimal number.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
		return err
	}

	wide := int64(i)
	if ejvw.canonical {
		ejvw.writeExtendedSingleValue("numberInt", strconv.FormatInt(wide, 10), true)
	} else {
		ejvw.buf = strconv.AppendInt(ejvw.buf, wide, 10)
	}

	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteInt64 writes i followed by a trailing comma. In canonical mode the value is
// wrapped as {"$numberLong":"<i>"}; otherwise it is written as a bare decimal number.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
		return err
	}

	if ejvw.canonical {
		ejvw.writeExtendedSingleValue("numberLong", strconv.FormatInt(i, 10), true)
	} else {
		ejvw.buf = strconv.AppendInt(ejvw.buf, i, 10)
	}

	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteJavascript writes code as {"$code":<escaped code>} followed by a trailing comma.
// The code is quoted and escaped by writeStringWithEscapes. An error is returned when
// the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
		return err
	}

	var escaped bytes.Buffer
	writeStringWithEscapes(code, &escaped, ejvw.escapeHTML)

	// The escaped text already carries its own quotes, so none are added here.
	ejvw.writeExtendedSingleValue("code", escaped.String(), false)
	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteMaxKey writes {"$maxKey":1} followed by a trailing comma. An error is returned
// when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
	err := ejvw.ensureElementValue(mode(0), "WriteMaxKey")
	if err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, `{"$maxKey":1},`...)
	ejvw.pop()
	return nil
}
// WriteMinKey writes {"$minKey":1} followed by a trailing comma. An error is returned
// when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteMinKey() error {
	err := ejvw.ensureElementValue(mode(0), "WriteMinKey")
	if err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, `{"$minKey":1},`...)
	ejvw.pop()
	return nil
}
// WriteNull writes the JSON literal null followed by a trailing comma. An error is
// returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteNull() error {
	err := ejvw.ensureElementValue(mode(0), "WriteNull")
	if err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, "null,"...)
	ejvw.pop()
	return nil
}
// WriteObjectID writes oid as {"$oid":"<hex>"} followed by a trailing comma. An error
// is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteObjectID(oid ObjectID) error {
	err := ejvw.ensureElementValue(mode(0), "WriteObjectID")
	if err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, `{"$oid":"`...)
	ejvw.buf = append(ejvw.buf, oid.Hex()...)
	ejvw.buf = append(ejvw.buf, `"},`...)
	ejvw.pop()
	return nil
}
// WriteRegex writes a regular expression as
// {"$regularExpression":{"pattern":<pattern>,"options":<options>}} followed by a
// trailing comma. The option characters are sorted in ascending order first, and both
// strings are quoted and escaped by writeStringWithEscapes. An error is returned when
// the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
		return err
	}

	sortedOptions := sortStringAlphebeticAscending(options)

	var out bytes.Buffer
	out.WriteString(`{"$regularExpression":{"pattern":`)
	writeStringWithEscapes(pattern, &out, ejvw.escapeHTML)
	out.WriteString(`,"options":`)
	writeStringWithEscapes(sortedOptions, &out, ejvw.escapeHTML)
	out.WriteString(`}},`)

	ejvw.buf = append(ejvw.buf, out.Bytes()...)
	ejvw.pop()
	return nil
}
// WriteString writes s as a quoted, escaped JSON string followed by a trailing comma.
// An error is returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteString(s string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
		return err
	}

	var escaped bytes.Buffer
	writeStringWithEscapes(s, &escaped, ejvw.escapeHTML)
	escaped.WriteByte(',')

	ejvw.buf = append(ejvw.buf, escaped.Bytes()...)
	ejvw.pop()
	return nil
}
// WriteSymbol writes symbol as {"$symbol":<escaped symbol>} followed by a trailing
// comma. The symbol is quoted and escaped by writeStringWithEscapes. An error is
// returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
		return err
	}

	var escaped bytes.Buffer
	writeStringWithEscapes(symbol, &escaped, ejvw.escapeHTML)

	// The escaped text already carries its own quotes, so none are added here.
	ejvw.writeExtendedSingleValue("symbol", escaped.String(), false)
	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// WriteTimestamp writes {"$timestamp":{"t":<t>,"i":<i>}} followed by a trailing comma.
// Both numbers are written as bare unsigned decimals. An error is returned when the
// writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
		return err
	}

	tText := strconv.FormatUint(uint64(t), 10)
	iText := strconv.FormatUint(uint64(i), 10)

	var out bytes.Buffer
	out.WriteString(`{"$timestamp":{"t":` + tText + `,"i":` + iText + `}},`)

	ejvw.buf = append(ejvw.buf, out.Bytes()...)
	ejvw.pop()
	return nil
}
// WriteUndefined writes {"$undefined":true} followed by a trailing comma. An error is
// returned when the writer is not positioned on an element or value.
func (ejvw *extJSONValueWriter) WriteUndefined() error {
	err := ejvw.ensureElementValue(mode(0), "WriteUndefined")
	if err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, `{"$undefined":true},`...)
	ejvw.pop()
	return nil
}
// WriteDocumentElement starts a new field in the current document. It emits the
// quoted, escaped key followed by a colon, enters element mode and returns the writer
// itself as the ValueWriter for the field's value. A TransitionError is returned when
// the writer is not inside a document, the top level or a code-with-scope.
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
	current := ejvw.stack[ejvw.frame].mode
	if current != mDocument && current != mTopLevel && current != mCodeWithScope {
		allowed := []mode{mDocument, mTopLevel, mCodeWithScope}
		return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", allowed)
	}

	var escaped bytes.Buffer
	writeStringWithEscapes(key, &escaped, ejvw.escapeHTML)

	ejvw.buf = append(ejvw.buf, fmt.Sprintf(`%s:`, escaped.String())...)
	ejvw.push(mElement)
	return ejvw, nil
}
// WriteDocumentEnd closes the document currently being written.
//
// The closing brace replaces a trailing comma left by the last value, or is appended
// when there is none. What follows depends on the mode:
//   - code-with-scope: a second brace closes the wrapper object, then a comma is added;
//   - nested document: a comma is added;
//   - top level: an optional newline is added and, when an io.Writer is attached, the
//     buffer is written to it and then emptied.
//
// An error is returned when the writer is not in a document-like mode, when the output
// buffer is empty (nothing has been opened, so there is nothing to close), or when
// writing to the attached io.Writer fails.
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
	current := ejvw.stack[ejvw.frame].mode
	switch current {
	case mDocument, mTopLevel, mCodeWithScope:
	default:
		return fmt.Errorf("incorrect mode to end document: %s", current)
	}

	// Every way of opening a document appends at least one byte, so an empty buffer
	// means there is no open document. Report that instead of indexing at -1.
	last := len(ejvw.buf) - 1
	if last < 0 {
		return fmt.Errorf("incorrect mode to end document: %s: no document has been started", current)
	}

	// close the document
	if ejvw.buf[last] == ',' {
		ejvw.buf[last] = '}'
	} else {
		ejvw.buf = append(ejvw.buf, '}')
	}

	switch current {
	case mCodeWithScope:
		ejvw.buf = append(ejvw.buf, '}')
		fallthrough
	case mDocument:
		ejvw.buf = append(ejvw.buf, ',')
	case mTopLevel:
		// If the value writer has newlines enabled, end top-level documents with a newline so that
		// multiple documents encoded to the same writer are separated by newlines. That matches the
		// Go json.Encoder behavior and also works with NewExtJSONValueReader.
		if ejvw.newlines {
			ejvw.buf = append(ejvw.buf, '\n')
		}
		if ejvw.w != nil {
			if _, err := ejvw.w.Write(ejvw.buf); err != nil {
				return err
			}
			ejvw.buf = ejvw.buf[:0]
		}
	}

	ejvw.pop()
	return nil
}
// WriteArrayElement starts a new element in the current array by entering value mode
// and returns the writer itself as the ValueWriter. A TransitionError is returned
// when the writer is not inside an array.
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
	if ejvw.stack[ejvw.frame].mode != mArray {
		return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
	}
	ejvw.push(mValue)
	return ejvw, nil
}
// WriteArrayEnd closes the array currently being written. The closing bracket replaces
// a trailing comma left by the last element, or is appended when there is none; a
// comma is then added after it. An error is returned when the writer is not inside an
// array.
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
	if current := ejvw.stack[ejvw.frame].mode; current != mArray {
		return fmt.Errorf("incorrect mode to end array: %s", current)
	}

	// close the array
	last := len(ejvw.buf) - 1
	if ejvw.buf[last] == ',' {
		ejvw.buf[last] = ']'
	} else {
		ejvw.buf = append(ejvw.buf, ']')
	}

	ejvw.buf = append(ejvw.buf, ',')
	ejvw.pop()
	return nil
}
// formatDouble renders f as text. The non-finite values become "Infinity", "-Infinity"
// and "NaN". Finite values use the shortest 'G' formatting that round-trips, and ".0"
// is appended when the result contains neither an exponent nor a decimal point.
func formatDouble(f float64) string {
	if math.IsNaN(f) {
		return "NaN"
	}
	if math.IsInf(f, 0) {
		if f > 0 {
			return "Infinity"
		}
		return "-Infinity"
	}

	text := strconv.FormatFloat(f, 'G', -1, 64)
	if strings.ContainsAny(text, "E.") {
		return text
	}
	// Whole numbers always get exactly one decimal place.
	return text + ".0"
}
// hexChars maps a 4-bit value to its lowercase hexadecimal digit.
var hexChars = "0123456789abcdef"
// writeStringWithEscapes appends s to buf as a double-quoted JSON string.
//
// Bytes below utf8.RuneSelf that the lookup tables mark as safe are copied through
// unchanged. htmlSafeSet is always consulted and safeSet only when escapeHTML is false
// (both tables are defined outside this block; presumably safeSet is the more
// permissive of the two, verify against their definitions). Other single bytes are
// written as a short escape (\\, \", \n, \r, \t, \b, \f) or as \u00XX. Invalid UTF-8
// is replaced by \ufffd, and U+2028 / U+2029 are always written as \u2028 / \u2029.
//
// start marks the beginning of the pending run of bytes that need no escaping; the
// run is flushed to buf whenever an escape has to be written and once more at the end.
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
// Flush the unescaped run that precedes this byte.
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes every remaining single byte that was not marked safe,
// i.e. those not handled by the short escapes above (\\, \", \n, \r,
// \t, \b and \f), as \u00XX.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
// RuneError with size 1 signals an invalid UTF-8 byte; emit the replacement
// character escape in its place.
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
// The two code points differ only in their last hex digit.
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
// Any other multi-byte rune stays in the pending unescaped run.
i += size
}
// Flush whatever unescaped text remains.
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
// sortableString is a rune slice that implements sort.Interface, ordering runes by
// ascending code point.
type sortableString []rune

// Len reports the number of runes.
func (ss sortableString) Len() int { return len(ss) }

// Less reports whether the rune at i sorts before the rune at j.
func (ss sortableString) Less(i, j int) bool { return ss[j] > ss[i] }

// Swap exchanges the runes at i and j.
func (ss sortableString) Swap(i, j int) {
	held := ss[i]
	ss[i] = ss[j]
	ss[j] = held
}

// sortStringAlphebeticAscending returns s with its runes rearranged in ascending
// code point order.
func sortStringAlphebeticAscending(s string) string {
	runes := []rune(s)
	sort.Sort(sortableString(runes))
	return string(runes)
}
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"testing"
)
func FuzzDecode(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
for _, typ := range []func() interface{}{
func() interface{} { return new(D) },
func() interface{} { return new([]E) },
func() interface{} { return new(M) },
func() interface{} { return new(interface{}) },
func() interface{} { return make(map[string]interface{}) },
func() interface{} { return new([]interface{}) },
} {
i := typ()
if err := Unmarshal(data, i); err != nil {
return
}
encoded, err := Marshal(i)
if err != nil {
t.Fatal("failed to marshal", err)
}
if err := Unmarshal(encoded, i); err != nil {
t.Fatal("failed to unmarshal", err)
}
}
})
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"unicode"
"unicode/utf16"
)
// jsonTokenType identifies the kind of token produced by jsonScanner.
type jsonTokenType byte
// Token kinds, numbered consecutively from zero.
const (
// Structural characters: '{', '}', '[', ']', ':' and ','.
jttBeginObject jsonTokenType = iota
jttEndObject
jttBeginArray
jttEndArray
jttColon
jttComma
// Numbers, as classified by scanNumber.
jttInt32
jttInt64
jttDouble
// Quoted strings, produced by scanString.
jttString
// The literals true/false and null, produced by scanLiteral.
jttBool
jttNull
// End of input, produced by nextToken when the reader reports io.EOF.
jttEOF
)
// jsonToken is a single token produced by jsonScanner.
type jsonToken struct {
// t is the kind of token.
t jsonTokenType
// v is the token's value; its dynamic type depends on t (for example byte for
// structural characters, string for jttString, bool for jttBool).
v interface{}
// p is the scanner position recorded for the token's first byte.
p int
}
// jsonScanner splits the bytes read from an io.Reader into JSON tokens.
type jsonScanner struct {
// r is the source of input bytes.
r io.Reader
// buf holds the most recently read chunk of input.
buf []byte
// pos is the index in buf of the next byte to hand out.
pos int
// lastReadErr remembers an error returned by r so that it is reported once the
// bytes read alongside it have been consumed, without calling r again.
lastReadErr error
}
// nextToken skips leading whitespace and returns the next JSON token: a structural
// character, a number, a string or a literal. Reaching the end of input yields a
// jttEOF token with a nil error; any other read error is returned as-is. A byte that
// cannot start a token produces an "invalid JSON input" error.
func (js *jsonScanner) nextToken() (*jsonToken, error) {
	var (
		c   byte
		err error
	)
	// keep reading until a non-space is encountered (stop on read error or EOF)
	for {
		c, err = js.readNextByte()
		if err != nil || !isWhiteSpace(c) {
			break
		}
	}
	if err != nil {
		if errors.Is(err, io.EOF) {
			return &jsonToken{t: jttEOF}, nil
		}
		return nil, err
	}

	// Multi-byte tokens are handed to their dedicated scanners.
	switch {
	case c == '"': // RFC-8259 only allows for double quotes (") not single (')
		return js.scanString()
	case c == '-' || isDigit(c):
		return js.scanNumber(c)
	case c == 't' || c == 'f' || c == 'n':
		// maybe a literal
		return js.scanLiteral(c)
	}

	// Everything else must be a single structural character.
	var kind jsonTokenType
	switch c {
	case '{':
		kind = jttBeginObject
	case '}':
		kind = jttEndObject
	case '[':
		kind = jttBeginArray
	case ']':
		kind = jttEndArray
	case ':':
		kind = jttColon
	case ',':
		kind = jttComma
	default:
		return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
	}
	return &jsonToken{t: kind, v: c, p: js.pos - 1}, nil
}
// readNextByte returns the next unread byte and advances the read position. When the
// buffer has been fully consumed it is refilled through readIntoBuf first, which also
// resets the read position to 0; an error from the refill is returned with a zero byte.
func (js *jsonScanner) readNextByte() (byte, error) {
	exhausted := js.pos >= len(js.buf)
	if exhausted {
		if err := js.readIntoBuf(); err != nil {
			return 0, err
		}
	}

	next := js.buf[js.pos]
	js.pos++
	return next, nil
}
// readNNextBytes reads n bytes from the scanner into dst[offset:offset+n]. Reading
// stops at the first error, which is returned; the slot being filled at that point
// receives the zero byte that accompanies the error.
func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
	for i := offset; i < offset+n; i++ {
		b, err := js.readNextByte()
		dst[i] = b
		if err != nil {
			return err
		}
	}
	return nil
}
// readIntoBuf refills the buffer from the scanner's io.Reader and resets the read
// position to 0. The buffer is allocated with a capacity of 512 bytes on first use;
// each call reads at most cap(js.buf) bytes.
//
// An error returned by the reader together with data is remembered in lastReadErr and
// only reported by the next call, after the data has been consumed. Once an error has
// been remembered the reader is never called again.
//
// io.Reader permits Read to return (0, nil), meaning "nothing happened". Such reads
// are retried a bounded number of times; if the reader still makes no progress,
// io.ErrNoProgress is returned so that callers never index into an empty buffer.
func (js *jsonScanner) readIntoBuf() error {
	if js.lastReadErr != nil {
		js.buf = js.buf[:0]
		js.pos = 0
		return js.lastReadErr
	}

	if cap(js.buf) == 0 {
		js.buf = make([]byte, 0, 512)
	}

	// Same bound that bufio uses for consecutive empty reads.
	const maxEmptyReads = 100

	var (
		n   int
		err error
	)
	for attempt := 0; attempt < maxEmptyReads; attempt++ {
		n, err = js.r.Read(js.buf[:cap(js.buf)])
		if n > 0 || err != nil {
			break
		}
	}
	if n == 0 && err == nil {
		err = io.ErrNoProgress
	}

	if err != nil {
		js.lastReadErr = err
		// Hand out the bytes that were read before surfacing the error.
		if n > 0 {
			err = nil
		}
	}

	js.buf = js.buf[:n]
	js.pos = 0
	return err
}
// isWhiteSpace reports whether c is a space, tab, carriage return or line feed.
func isWhiteSpace(c byte) bool {
	switch c {
	case ' ', '\t', '\r', '\n':
		return true
	default:
		return false
	}
}
// isDigit reports whether c is one of the ASCII digits '0' through '9', the only
// digits the JSON number grammar allows. The comparison is done directly on the byte
// rather than through the Unicode digit tables, which cover far more than a single
// input byte can ever represent.
func isDigit(c byte) bool {
	return '0' <= c && c <= '9'
}
// isValueTerminator reports whether c can directly follow a JSON value: a comma, a
// closing brace or bracket, or whitespace.
func isValueTerminator(c byte) bool {
	switch c {
	case ',', '}', ']':
		return true
	}
	return isWhiteSpace(c)
}
// getu4 decodes the four hexadecimal digits at the start of s and returns their value
// as a rune. It returns -1 when s is shorter than four bytes or when any of the first
// four bytes is not a hex digit. The "\u" prefix of the escape sequence must already
// have been removed. Adapted from the Go JSON decoder:
// https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188
func getu4(s []byte) rune {
	if len(s) < 4 {
		return -1
	}

	var value rune
	for i := 0; i < 4; i++ {
		var nibble byte
		switch c := s[i]; {
		case c >= '0' && c <= '9':
			nibble = c - '0'
		case c >= 'a' && c <= 'f':
			nibble = c - 'a' + 10
		case c >= 'A' && c <= 'F':
			nibble = c - 'A' + 10
		default:
			return -1
		}
		value = value<<4 | rune(nibble)
	}
	return value
}
// scanString reads from an opening '"' to a closing '"' and handles escaped characters.
//
// The opening quote has already been consumed by the caller; the returned token's
// position refers to that quote. Bytes are copied through unchanged except for
// backslash escapes, which are decoded. \uXXXX escapes that denote a UTF-16 surrogate
// are combined with a following \uXXXX escape when the two form a valid pair;
// otherwise the Unicode replacement character is written for the lone surrogate.
//
// Errors are returned when the input ends before the closing quote, when fewer than
// four bytes are available after \u, and when an unknown escape letter follows a
// backslash.
//
// NOTE(review): the result of getu4 is not checked for -1, so non-hex digits after \u
// are passed straight to WriteRune; confirm that this is the intended handling.
func (js *jsonScanner) scanString() (*jsonToken, error) {
var b bytes.Buffer
var c byte
var err error
// Position of the opening quote.
p := js.pos - 1
for {
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// Re-entry point used when a byte has already been read while looking for the
// second half of a surrogate pair and still needs to be processed normally.
evalNextChar:
switch c {
case '\\':
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// Re-entry point used when a backslash has already been consumed while looking
// for the second half of a surrogate pair and c holds the escape letter.
evalNextEscapeChar:
switch c {
case '"', '\\', '/':
b.WriteByte(c)
case 'b':
b.WriteByte('\b')
case 'f':
b.WriteByte('\f')
case 'n':
b.WriteByte('\n')
case 'r':
b.WriteByte('\r')
case 't':
b.WriteByte('\t')
case 'u':
us := make([]byte, 4)
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn := getu4(us)
// If the rune we just decoded is the high or low value of a possible surrogate pair,
// try to decode the next sequence as the low value of a surrogate pair. We're
// expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"),
// but need to handle cases where the input is not a valid surrogate pair.
// For more context on unicode surrogate pairs, see:
// https://www.christianfscott.com/rust-chars-vs-go-runes/
// https://www.unicode.org/glossary/#high_surrogate_code_point
if utf16.IsSurrogate(rn) {
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a backslash escape sequence, write
// the Unicode replacement character for the surrogate value and goto the
// beginning of the next char eval block.
if c != '\\' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextChar
}
c, err = js.readNextByte()
if err != nil {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a unicode escape sequence, write the
// Unicode replacement character for the surrogate value and goto the beginning
// of the next escape char eval block.
if c != 'u' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextEscapeChar
}
// us is reused for the second escape's four hex digits.
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn2 := getu4(us)
// Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write
// the Unicode replacement character for the surrogate value and the 2nd decoded rune.
if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar {
b.WriteRune(rnPair)
} else {
b.WriteRune(unicode.ReplacementChar)
b.WriteRune(rn2)
}
// The pair has been fully handled; skip the single-rune write below.
break
}
b.WriteRune(rn)
default:
return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
}
case '"':
// Closing quote: the string is complete.
return &jsonToken{t: jttString, v: b.String(), p: p}, nil
default:
b.WriteByte(c)
}
}
}
// scanLiteral reads an unquoted sequence of characters and determines if it is one of
// three valid JSON literals (true, false, null); if so, it returns the appropriate
// jsonToken; otherwise, it returns an error.
//
// first is the byte already consumed by the caller. Three more bytes are read to form
// a four-byte candidate, then one further byte (two for "false") is read as lookahead.
// The literal is accepted only when that lookahead byte is a value terminator or the
// input has ended. After acceptance the read position is moved back by one, never
// below zero, which un-reads the lookahead byte when it came from the current buffer.
//
// An error from reading the three candidate bytes is returned unchanged. Every other
// failure yields an "invalid JSON literal" error that reports the four-byte candidate,
// so a malformed "false" is reported as "fals".
func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
// Position of the literal's first byte.
p := js.pos - 1
lit := make([]byte, 4)
lit[0] = first
err := js.readNNextBytes(lit, 3, 1)
if err != nil {
return nil, err
}
// Lookahead byte: a terminator for "true"/"null", or the final 'e' of "false".
c5, err := js.readNextByte()
switch {
case bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)):
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: true, p: p}, nil
case bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)):
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttNull, v: nil, p: p}, nil
case bytes.Equal([]byte("fals"), lit):
if c5 == 'e' {
// "false" is five bytes long, so its lookahead is one byte further on.
c5, err = js.readNextByte()
if isValueTerminator(c5) || errors.Is(err, io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: false, p: p}, nil
}
}
}
return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
}
// numberScanState is the state of the state machine that scanNumber uses to validate
// a JSON number. Apart from the two terminal states, each state is named after the
// part of the number that was consumed most recently.
type numberScanState byte
const (
// A leading '-' was read; a digit must follow.
nssSawLeadingMinus numberScanState = iota
// A leading '0' was read; no further integer digits may follow.
nssSawLeadingZero
// One or more integer digits, the first of them non-zero, were read.
nssSawIntegerDigits
// A '.' was read; a digit must follow.
nssSawDecimalPoint
// One or more digits after the decimal point were read.
nssSawFractionDigits
// An 'e' or 'E' was read; a sign or a digit must follow.
nssSawExponentLetter
// A '+' or '-' after the exponent letter was read; a digit must follow.
nssSawExponentSign
// One or more exponent digits were read.
nssSawExponentDigits
// The number ended at a terminator or at the end of input.
nssDone
// The input is not a valid JSON number.
nssInvalid
)
// scanNumber reads a JSON number (according to RFC-8259).
//
// first is the byte already consumed by the caller ('-' or a digit). The remaining
// bytes are consumed one at a time and run through the numberScanState machine, which
// collects the number's text in a buffer. A number ends at '}', ']', ',', whitespace
// or the end of input; any byte that is invalid for the current state produces an
// "invalid JSON number" error carrying the number's start position.
//
// The token type is chosen as follows:
//   - text without fraction or exponent that parses as an int64 yields jttInt32 when
//     the value fits in an int32 and jttInt64 otherwise;
//   - everything else, including integers that fail to parse as int64, is parsed as a
//     float64 and yields jttDouble.
//
// After a number completes, the read position is moved back by one, never below zero,
// which un-reads the terminating byte when it came from the current buffer.
func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
var b bytes.Buffer
var s numberScanState
var c byte
var err error
t := jttInt64 // assume it's an int64 until the type can be determined
// Position of the number's first byte.
start := js.pos - 1
b.WriteByte(first)
// Choose the initial state from the first byte.
switch first {
case '-':
s = nssSawLeadingMinus
case '0':
s = nssSawLeadingZero
default:
s = nssSawIntegerDigits
}
for {
c, err = js.readNextByte()
// io.EOF is not returned here: it is treated as a possible end of the number by
// the states below.
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}
switch s {
case nssSawLeadingMinus:
// A digit is required after the sign.
switch c {
case '0':
s = nssSawLeadingZero
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawLeadingZero:
// After a leading zero only '.', an exponent or the end of the number is valid.
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || errors.Is(err, io.EOF) {
s = nssDone
} else {
s = nssInvalid
}
}
case nssSawIntegerDigits:
// More digits, '.', an exponent or the end of the number.
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawIntegerDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
case nssSawDecimalPoint:
// A decimal point makes the number a double and must be followed by a digit.
t = jttDouble
if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawFractionDigits:
// More fraction digits, an exponent or the end of the number.
switch c {
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawFractionDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
case nssSawExponentLetter:
// An exponent makes the number a double; a sign or a digit must follow.
t = jttDouble
switch c {
case '+', '-':
s = nssSawExponentSign
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentSign:
// A digit is required after the exponent sign.
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawExponentDigits:
// More exponent digits or the end of the number.
switch c {
case '}', ']', ',':
s = nssDone
default:
switch {
case isWhiteSpace(c) || errors.Is(err, io.EOF):
s = nssDone
case isDigit(c):
s = nssSawExponentDigits
b.WriteByte(c)
default:
s = nssInvalid
}
}
}
// Act on the two terminal states; in any other state keep reading.
switch s {
case nssInvalid:
return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
case nssDone:
// Step back over the byte that ended the number.
js.pos = int(math.Max(0, float64(js.pos-1)))
if t != jttDouble {
v, err := strconv.ParseInt(b.String(), 10, 64)
if err == nil {
if v < math.MinInt32 || v > math.MaxInt32 {
return &jsonToken{t: jttInt64, v: v, p: start}, nil
}
return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
}
// ParseInt failed: fall through and parse the text as a double instead.
}
v, err := strconv.ParseFloat(b.String(), 64)
if err != nil {
return nil, err
}
return &jsonToken{t: jttDouble, v: v, p: start}, nil
}
}
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding"
"errors"
"fmt"
"reflect"
"strconv"
)
// mapCodec is the Codec used for map values.
type mapCodec struct {
// decodeZerosMap causes DecodeValue to delete any existing entries from a non-empty
// destination Go map before unmarshaling a BSON document into it.
decodeZerosMap bool
// encodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
// BSON null.
encodeNilAsEmpty bool
// encodeKeysWithStringer causes map keys to be converted to BSON document field name
// strings using fmt.Sprint instead of the default string conversion logic. When
// decoding, it also disables KeyUnmarshaler lookup and permits float map keys.
encodeKeysWithStringer bool
}
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
//
// When encoding map keys, KeyMarshaler is checked before encoding.TextMarshaler.
type KeyMarshaler interface {
// MarshalKey returns the string to use as the BSON document field name for the key.
MarshalKey() (key string, err error)
}
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
//
// When decoding map keys, the method is called on a pointer to a newly allocated key
// value, so it is found for types that implement it with a pointer receiver.
type KeyUnmarshaler interface {
// UnmarshalKey sets the receiver from the BSON document field name key.
UnmarshalKey(key string) error
}
// EncodeValue is the ValueEncoder for map[*]* types.
//
// A nil map is written as BSON null unless either the codec or the EncodeContext asks
// for nil maps to be encoded as empty documents. If null cannot be written, or the
// map is non-nil, the map is written as a document whose elements come from
// encodeMapElements. A ValueEncoderError is returned when val is not a valid map.
func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Map {
		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
	}

	// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
	// to a TopLevel document. We can't currently tell if this is what actually happened, but if
	// there's a deeper underlying problem, the error will also be returned from WriteDocument,
	// so just continue. The operations on a map reflection value are valid, so we can call
	// MapKeys within mapEncodeValue without a problem.
	wantNull := val.IsNil() && !mc.encodeNilAsEmpty && !ec.nilMapAsEmpty
	if wantNull && vw.WriteNull() == nil {
		return nil
	}

	dw, err := vw.WriteDocument()
	if err != nil {
		return err
	}

	if elemErr := mc.encodeMapElements(ec, dw, val, nil); elemErr != nil {
		return elemErr
	}
	return dw.WriteDocumentEnd()
}
// encodeMapElements handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
//
// Each map entry is written to dw as one document element: the key is converted to a
// string by encodeKey and the value is written with the encoder resolved by
// lookupElementEncoder. An entry whose lookup reports errInvalidValue is written as
// BSON null. Entries are visited in the order returned by reflect's MapKeys.
//
// The first error encountered is returned, including a key collision reported by
// collisionFn. Elements written before the error are not undone.
func (mc *mapCodec) encodeMapElements(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
// A failed lookup is tolerated for interface element types; the encoder is resolved
// per value below.
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
if err != nil {
return err
}
if collisionFn != nil && collisionFn(keyStr) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
// Resolve the encoder before writing the element key so that a lookup failure does
// not leave a key without a value in the output.
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}
vw, err := dw.WriteDocumentElement(keyStr)
if err != nil {
return err
}
// errInvalidValue means there is no value to encode; write null in its place.
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return nil
}
// DecodeValue is the ValueDecoder for map[string/decimal]* types.
//
// val must be a map that is either settable or already non-nil. BSON null and
// undefined set val to its zero value. An embedded document (or a reader reporting
// Type(0)) is decoded element by element into the map, which is allocated first when
// nil and emptied first when it has entries and zeroing was requested by the codec or
// the DecodeContext. Any other BSON type is an error.
//
// Keys are converted with decodeKey. An error while decoding an element's value is
// wrapped with the element's key. Entries stored before an error remain in the map.
func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
switch vrType := vr.Type(); vrType {
case Type(0), TypeEmbeddedDocument:
// Decoded as a document below.
case TypeNull:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case TypeUndefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeMap(val.Type()))
}
if val.Len() > 0 && (mc.decodeZerosMap || dc.zeroMaps) {
clearMap(val)
}
// The element decoder is looked up once and reused for every entry.
eType := val.Type().Elem()
decoder, err := dc.LookupDecoder(eType)
if err != nil {
return err
}
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
// ErrEOD marks the end of the document.
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
k, err := mc.decodeKey(key, keyType)
if err != nil {
return err
}
elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType)
if err != nil {
return newDecodeError(key, err)
}
val.SetMapIndex(k, elem)
}
return nil
}
// clearMap removes from the map m every key reported by MapKeys. Deletion is done by
// setting each key to the zero reflect.Value.
func clearMap(m reflect.Value) {
	keys := m.MapKeys()
	for i := range keys {
		m.SetMapIndex(keys[i], reflect.Value{})
	}
}
// encodeKey converts a map key into the string used as the BSON document field name.
//
// The first rule that applies wins:
//  1. if stringer mode is enabled on the codec or by the argument, fmt.Sprint is used;
//  2. keys of any string kind are used directly;
//  3. a KeyMarshaler is asked for its key;
//  4. an encoding.TextMarshaler is asked for its text;
//  5. signed and unsigned integer kinds are formatted in base 10.
//
// A nil pointer that implements one of the marshaler interfaces encodes as the empty
// string. Any other key type is an error.
func (mc *mapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
	if mc.encodeKeysWithStringer || encodeKeysWithStringer {
		return fmt.Sprint(val), nil
	}

	kind := val.Kind()

	// keys of any string type are used directly
	if kind == reflect.String {
		return val.String(), nil
	}

	switch marshaler := val.Interface().(type) {
	case KeyMarshaler:
		if kind == reflect.Ptr && val.IsNil() {
			return "", nil
		}
		key, err := marshaler.MarshalKey()
		if err != nil {
			return "", err
		}
		return key, nil
	case encoding.TextMarshaler:
		if kind == reflect.Ptr && val.IsNil() {
			return "", nil
		}
		text, err := marshaler.MarshalText()
		if err != nil {
			return "", err
		}
		return string(text), nil
	}

	switch kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return strconv.FormatUint(val.Uint(), 10), nil
	}

	return "", fmt.Errorf("unsupported key type: %v", val.Type())
}
// keyUnmarshalerType is the reflect.Type of the KeyUnmarshaler interface, used by
// decodeKey to test whether a pointer to the key type implements it.
var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
// textUnmarshalerType is the reflect.Type of the encoding.TextUnmarshaler interface,
// used by decodeKey in the same way.
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
// decodeKey converts the BSON document field name key into a map key of type keyType.
//
// The first rule that applies wins:
//  1. if encodeKeysWithStringer is disabled and *keyType implements KeyUnmarshaler,
//     UnmarshalKey is called on a newly allocated key;
//  2. if *keyType implements encoding.TextUnmarshaler, UnmarshalText is called on a
//     newly allocated key;
//  3. otherwise the key is converted according to keyType's kind: string kinds are
//     converted directly, signed and unsigned integer kinds are parsed in base 10
//     with overflow checking, and float kinds are parsed only when
//     encodeKeysWithStringer is enabled.
//
// The returned value always has type keyType on success, so it can be passed to
// reflect's SetMapIndex for maps whose key type is a float32 or a named type. When
// the returned error is non-nil the returned value must not be used.
func (mc *mapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
	keyVal := reflect.ValueOf(key)
	var err error
	switch {
	// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
	case !mc.encodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
		keyVal = reflect.New(keyType)
		v := keyVal.Interface().(KeyUnmarshaler)
		err = v.UnmarshalKey(key)
		keyVal = keyVal.Elem()
	// Try to decode encoding.TextUnmarshalers.
	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
		keyVal = reflect.New(keyType)
		v := keyVal.Interface().(encoding.TextUnmarshaler)
		err = v.UnmarshalText([]byte(key))
		keyVal = keyVal.Elem()
	// Otherwise, go to type specific behavior
	default:
		switch keyType.Kind() {
		case reflect.String:
			keyVal = reflect.ValueOf(key).Convert(keyType)
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			n, parseErr := strconv.ParseInt(key, 10, 64)
			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
				err = fmt.Errorf("failed to unmarshal number key %v", key)
				// Do not convert a value that failed to parse or does not fit.
				break
			}
			keyVal = reflect.ValueOf(n).Convert(keyType)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			n, parseErr := strconv.ParseUint(key, 10, 64)
			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
				err = fmt.Errorf("failed to unmarshal number key %v", key)
				break
			}
			keyVal = reflect.ValueOf(n).Convert(keyType)
		case reflect.Float32, reflect.Float64:
			if mc.encodeKeysWithStringer {
				// Parse at the key type's own width so that out-of-range float32 keys
				// are reported as errors instead of being silently changed.
				parsed, err := strconv.ParseFloat(key, keyType.Bits())
				if err != nil {
					return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
				}
				// ParseFloat always returns a float64; convert so the value is
				// assignable to float32 and named float key types.
				keyVal = reflect.ValueOf(parsed).Convert(keyType)
				break
			}
			fallthrough
		default:
			return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
		}
	}
	return keyVal, err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/json"
"sync"
)
// defaultDstCap is the initial capacity, in bytes, of the destination slice that
// MarshalExtJSON allocates for its output.
const defaultDstCap = 256
// extjPool recycles extJSONValueWriter instances so MarshalExtJSON does not have to
// allocate a new writer on every call. New returns a zero-valued writer.
var extjPool = sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
}
// Marshaler is the interface implemented by types that can marshal themselves
// into a valid BSON document.
//
// Implementations of Marshaler must return a full BSON document. To create
// custom BSON marshaling behavior for individual values in a BSON document,
// implement the ValueMarshaler interface instead.
type Marshaler interface {
// MarshalBSON returns the bytes of a complete BSON document, or an error if
// the value cannot be marshaled.
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is the interface implemented by types that can marshal
// themselves into a valid BSON value. The format of the returned bytes must
// match the returned type.
//
// Implementations of ValueMarshaler must return an individual BSON value. To
// create custom BSON marshaling behavior for an entire BSON document, implement
// the Marshaler interface instead.
type ValueMarshaler interface {
// MarshalBSONValue returns the BSON type byte (typ) together with the value
// bytes (data) encoded in the format of that type, or an error if the value
// cannot be marshaled.
MarshalBSONValue() (typ byte, data []byte, err error)
}
// bufPool is a pool of *bytes.Buffer scratch buffers used by Marshal and
// MarshalValue while encoding BSON. Callers reset a buffer after taking it from
// the pool and copy the encoded bytes out before the buffer is returned.
var bufPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// Marshal returns the BSON encoding of val as a BSON document. If val is not a
// type that can be transformed into a document, MarshalValue should be used
// instead.
//
// Encoding is performed with the package's default registry. Struct tags are
// inspected and alter the marshaling process accordingly. The returned slice is
// a private copy and never aliases the pooled scratch buffer.
func Marshal(val interface{}) ([]byte, error) {
	scratch := bufPool.Get().(*bytes.Buffer)
	defer func() {
		// A sync.Pool works best when its entries have roughly the same
		// memory cost. Because the buffers grow with the encoded payload, a
		// buffer is only recycled when it is smaller than 16MiB (the maximum
		// wire message size supported by any current MongoDB server) and was
		// more than half occupied by this call.
		//
		// Policy based on
		// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
		tooLarge := scratch.Cap() >= 16*1024*1024
		mostlyEmpty := scratch.Cap()/2 >= scratch.Len()
		if tooLarge || mostlyEmpty {
			return
		}
		bufPool.Put(scratch)
	}()
	scratch.Reset()

	enc := encPool.Get().(*Encoder)
	defer encPool.Put(enc)
	enc.Reset(NewDocumentWriter(scratch))
	enc.SetRegistry(defaultRegistry)

	if err := enc.Encode(val); err != nil {
		return nil, err
	}

	// Copy out of the scratch buffer, which may be reused by another call.
	return append([]byte(nil), scratch.Bytes()...), nil
}
// MarshalValue returns the BSON encoding of val as a single BSON value: the
// value's BSON type and its encoded bytes.
//
// Encoding is performed with the package's default registry. If val is a
// struct, struct tags are inspected and alter the marshaling process
// accordingly. The returned slice is a private copy and never aliases the
// pooled scratch buffer.
func MarshalValue(val interface{}) (Type, []byte, error) {
	scratch := bufPool.Get().(*bytes.Buffer)
	defer func() {
		// A sync.Pool works best when its entries have roughly the same
		// memory cost. Because the buffers grow with the encoded payload, a
		// buffer is only recycled when it is smaller than 16MiB (the maximum
		// wire message size supported by any current MongoDB server) and was
		// more than half occupied when this call finished.
		//
		// Policy based on
		// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147
		tooLarge := scratch.Cap() >= 16*1024*1024
		mostlyEmpty := scratch.Cap()/2 >= scratch.Len()
		if tooLarge || mostlyEmpty {
			return
		}
		bufPool.Put(scratch)
	}()
	scratch.Reset()

	// The value is written as a document element with an empty name.
	flusher := newDocumentWriter(scratch)
	elemWriter, err := flusher.WriteDocumentElement("")
	if err != nil {
		return 0, nil, err
	}

	enc := encPool.Get().(*Encoder)
	defer encPool.Put(enc)
	enc.Reset(elemWriter)
	enc.SetRegistry(defaultRegistry)
	if err = enc.Encode(val); err != nil {
		return 0, nil, err
	}

	// Flush explicitly because a full document is not guaranteed to have been
	// written. After the flush the buffer holds:
	// [value type, 0 (null byte ending the empty element name), value bytes...]
	if err = flusher.Flush(); err != nil {
		return 0, nil, err
	}

	// Consume the two header bytes; what remains is the value itself.
	header := scratch.Next(2)
	// Fully copy the value bytes: the scratch buffer may be reused, so handing
	// out a reference to it could race with a later call.
	value := append([]byte{}, scratch.Bytes()...)
	return Type(header[0]), value, nil
}
// MarshalExtJSON returns the extended JSON encoding of val.
//
// canonical and escapeHTML are forwarded to the extended JSON value writer.
// Encoding uses the package's default registry.
func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) {
	out := sliceWriter(make([]byte, 0, defaultDstCap))

	writer := extjPool.Get().(*extJSONValueWriter)
	writer.reset(out, canonical, escapeHTML)
	writer.w = &out
	defer func() {
		// Drop the references to this call's buffers before the writer goes
		// back into the pool.
		writer.buf = nil
		writer.w = nil
		extjPool.Put(writer)
	}()

	enc := encPool.Get().(*Encoder)
	defer encPool.Put(enc)
	enc.Reset(writer)
	enc.ec = EncodeContext{Registry: defaultRegistry}

	if err := enc.Encode(val); err != nil {
		return nil, err
	}
	return out, nil
}
// IndentExtJSON appends an indented form of the extended JSON in src to dst.
// Each line of the output after the first begins with prefix followed by one
// copy of indent per nesting level, as done by json.Indent.
func IndentExtJSON(dst *bytes.Buffer, src []byte, prefix, indent string) error {
	if err := json.Indent(dst, src, prefix, indent); err != nil {
		return err
	}
	return nil
}
// MarshalExtJSONIndent returns the extended JSON encoding of val, formatted so
// that every line is prefixed with prefix and indented with indent.
func MarshalExtJSONIndent(val interface{}, canonical, escapeHTML bool, prefix, indent string) ([]byte, error) {
	compact, err := MarshalExtJSON(val, canonical, escapeHTML)
	if err != nil {
		return nil, err
	}

	var indented bytes.Buffer
	if err = IndentExtJSON(&indented, compact, prefix, indent); err != nil {
		return nil, err
	}
	return indented.Bytes(), nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"reflect"
)
var (
// ErrMgoSetZero may be returned from a SetBSON method to have the value set to its respective zero value.
ErrMgoSetZero = errors.New("set to zero")
// tInt is the reflect.Type of int.
tInt = reflect.TypeOf(int(0))
// tM is the reflect.Type of M.
tM = reflect.TypeOf(M{})
// tInterfaceSlice is the reflect.Type of []interface{}.
tInterfaceSlice = reflect.TypeOf([]interface{}{})
// tGetter is the reflect.Type of the getter interface.
tGetter = reflect.TypeOf((*getter)(nil)).Elem()
// tSetter is the reflect.Type of the setter interface.
tSetter = reflect.TypeOf((*setter)(nil)).Elem()
)
// NewMgoRegistry creates a new bson.Registry that starts from NewRegistry and
// then overrides a set of encoders, decoders and type map entries:
//
//   - maps: zero on decode, nil encoded as empty, keys encoded with Stringer;
//   - structs: zero on decode, omit-default structs on encode, unexported
//     fields allowed, inline maps handled by the map codec above;
//   - slices and byte slices: nil encoded as empty;
//   - unsigned integers: encoded to the minimum size;
//   - empty interfaces: binary decoded as a slice;
//   - strings: decoded with mgoStringDecodeValue;
//   - getter/setter interfaces: handled by getterEncodeValue/setterDecodeValue.
func NewMgoRegistry() *Registry {
	mc := &mapCodec{
		decodeZerosMap:         true,
		encodeNilAsEmpty:       true,
		encodeKeysWithStringer: true,
	}
	sc := &structCodec{
		inlineMapEncoder:        mc,
		decodeZeroStruct:        true,
		encodeOmitDefaultStruct: true,
		allowUnexportedFields:   true,
	}
	uc := &uintCodec{encodeToMinSize: true}

	reg := NewRegistry()

	// Decoders.
	reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true})
	reg.RegisterKindDecoder(reflect.String, ValueDecoderFunc(mgoStringDecodeValue))
	reg.RegisterKindDecoder(reflect.Struct, sc)
	reg.RegisterKindDecoder(reflect.Map, mc)

	// Encoders.
	reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true})
	reg.RegisterKindEncoder(reflect.Struct, sc)
	reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true})
	reg.RegisterKindEncoder(reflect.Map, mc)
	uintKinds := []reflect.Kind{
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
	}
	for _, kind := range uintKinds {
		reg.RegisterKindEncoder(kind, uc)
	}

	// Go types used when decoding into an empty interface.
	reg.RegisterTypeMapEntry(TypeInt32, tInt)
	reg.RegisterTypeMapEntry(TypeDateTime, tTime)
	reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice)
	reg.RegisterTypeMapEntry(Type(0), tM)
	reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM)

	// Getter/setter hooks.
	reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(getterEncodeValue))
	reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(setterDecodeValue))
	return reg
}
// NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to
// behave like mgo/bson with RespectNilValues set to true.
//
// It starts from NewMgoRegistry and replaces the map, slice and byte-slice
// codecs with variants that do not encode nil as an empty value.
func NewRespectNilValuesMgoRegistry() *Registry {
	reg := NewMgoRegistry()

	nilAwareMaps := &mapCodec{decodeZerosMap: true}
	reg.RegisterKindDecoder(reflect.Map, nilAwareMaps)
	reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false})
	reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
	reg.RegisterKindEncoder(reflect.Map, nilAwareMaps)

	return reg
}
// mgoStringDecodeValue is the ValueDecoderFunc registered for string kinds by
// NewMgoRegistry.
//
// In addition to everything stringCodec decodes, it accepts a BSON ObjectID:
// the string receives the ObjectID's hex form when dc.objectIDAsHexString is
// set, and the raw 12 bytes otherwise. A non-string val yields a
// ValueDecoderError.
func mgoStringDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if val.Kind() != reflect.String {
		return ValueDecoderError{
			Name:     "StringDecodeValue",
			Kinds:    []reflect.Kind{reflect.String},
			Received: reflect.Zero(val.Type()),
		}
	}

	// Everything except ObjectID is handled by the regular string codec.
	if vr.Type() != TypeObjectID {
		return (&stringCodec{}).DecodeValue(dc, vr, val)
	}

	oid, err := vr.ReadObjectID()
	if err != nil {
		return err
	}

	decoded := string(oid[:])
	if dc.objectIDAsHexString {
		decoded = oid.Hex()
	}
	val.SetString(decoded)
	return nil
}
// setter interface: a value implementing the bson.Setter interface will receive the BSON
// value via the SetBSON method during unmarshaling, and the object
// itself will not be changed as usual.
//
// If setting the value works, the method should return nil or alternatively
// ErrMgoSetZero to set the respective field to its zero value (nil for
// pointer types). If SetBSON returns any other non-nil error, the unmarshalling
// procedure will stop and error out with the provided value.
//
// This interface is generally useful in pointer receivers, since the method
// will want to change the receiver. A type field that implements the Setter
// interface doesn't have to be a pointer, though.
//
// For example:
//
//	type MyString string
//
//	func (s *MyString) SetBSON(raw bson.RawValue) error {
//		return raw.Unmarshal(s)
//	}
type setter interface {
// SetBSON receives the raw BSON value that is being decoded into the receiver.
SetBSON(raw RawValue) error
}
// getter interface: a value implementing the bson.Getter interface will have its GetBSON
// method called when the given value has to be marshalled, and the result
// of this method will be marshaled in place of the actual object.
//
// If GetBSON returns a non-nil error, the marshalling procedure
// will stop and error out with the provided value.
type getter interface {
// GetBSON returns the value to marshal in place of the receiver.
GetBSON() (interface{}, error)
}
// setterDecodeValue is the ValueDecoderFunc for Setter types.
//
// Either val's type or a pointer to it must implement the setter interface;
// otherwise a ValueDecoderError is returned. A nil pointer val is allocated
// before SetBSON is called, which requires val to be settable. When only the
// pointer type implements setter, val must be addressable.
//
// The BSON value is copied out of vr and handed to SetBSON as a RawValue. If
// SetBSON returns ErrMgoSetZero, the destination value is reset to its zero
// value (nil for pointer types); a ValueDecoderError is returned when the
// destination cannot be set. Any other error from SetBSON is returned as is.
func setterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) {
		return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
	}

	if val.Kind() == reflect.Ptr && val.IsNil() {
		if !val.CanSet() {
			return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
		}
		val.Set(reflect.New(val.Type().Elem()))
	}

	// target is the destination value that gets zeroed on ErrMgoSetZero;
	// receiver is the value SetBSON is invoked on. They differ only when the
	// method set lives on the pointer type. The result of Addr is never
	// settable, so zeroing must always go through target.
	target, receiver := val, val
	if !val.Type().Implements(tSetter) {
		if !val.CanAddr() {
			return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
		}
		receiver = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
	}

	t, src, err := copyValueToBytes(vr)
	if err != nil {
		return err
	}

	m, ok := receiver.Interface().(setter)
	if !ok {
		return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: receiver}
	}
	if err := m.SetBSON(RawValue{Type: t, Value: src}); err != nil {
		if !errors.Is(err, ErrMgoSetZero) {
			return err
		}
		if !target.CanSet() {
			return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: target}
		}
		target.Set(reflect.Zero(target.Type()))
	}
	return nil
}
// getterEncodeValue is the ValueEncoderFunc for Getter types.
//
// Either val's type or a pointer to it must implement the getter interface;
// otherwise a ValueEncoderError is returned. BSON null is written when val is a
// nil implementation or when GetBSON returns nil. Otherwise the value returned
// by GetBSON is encoded with the encoder that ec.Registry provides for its
// dynamic type. Errors from GetBSON and from the encoder lookup are returned
// unchanged.
func getterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() {
		return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
	}

	if val.Type().Implements(tGetter) {
		// Getter is implemented on the concrete type: make sure val is not a
		// nil pointer before calling through it.
		if isImplementationNil(val, tGetter) {
			return vw.WriteNull()
		}
	} else if reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr() {
		// Only the pointer type implements Getter; call through the address.
		val = val.Addr()
	} else {
		return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
	}

	g, isGetter := val.Interface().(getter)
	if !isGetter {
		return vw.WriteNull()
	}

	replacement, err := g.GetBSON()
	if err != nil {
		return err
	}
	if replacement == nil {
		return vw.WriteNull()
	}

	rv := reflect.ValueOf(replacement)
	enc, err := ec.Registry.LookupEncoder(rv.Type())
	if err != nil {
		return err
	}
	return enc.EncodeValue(ec, vw, rv)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
)
// mode identifies a position in the ValueReader/ValueWriter state machine.
type mode int

// The zero value is deliberately left unnamed, so mode(0) can stand for
// "no mode".
const (
	_ mode = iota
	mTopLevel
	mDocument
	mArray
	mValue
	mElement
	mCodeWithScope
	mSpacer
)

// String returns the human-readable name of m, or "Unknown" for any value that
// is not one of the named modes.
func (m mode) String() string {
	switch m {
	case mTopLevel:
		return "TopLevel"
	case mDocument:
		return "Document"
	case mArray:
		return "Array"
	case mValue:
		return "Value"
	case mElement:
		return "Element"
	case mCodeWithScope:
		return "CodeWithScope"
	case mSpacer:
		return "CodeWithScopeSpacer"
	}
	return "Unknown"
}
// TransitionError is the error returned when a ValueReader or ValueWriter
// state machine is asked to make an invalid transition.
type TransitionError struct {
	name        string // subject of the message (printed first)
	parent      mode   // parent mode; omitted from the message when zero
	current     mode   // mode the state machine is currently positioned on
	destination mode   // mode the action targets; omitted from the message when zero
	modes       []mode // modes from which the action is permitted
	action      string // verb describing the attempted action
}

// Error renders the transition failure in the form
//
//	<name> can only <action> [a <destination>] while positioned on a <modes> but is positioned on a <current> [with parent <parent>]
//
// where the permitted modes are joined with commas (three or more modes) and a
// final "or" (two or more modes).
func (te TransitionError) Error() string {
	msg := fmt.Sprintf("%s can only %s", te.name, te.action)
	if te.destination != mode(0) {
		msg += fmt.Sprintf(" a %s", te.destination)
	}
	msg += " while positioned on a"

	count := len(te.modes)
	for i, m := range te.modes {
		if i != 0 && count > 2 {
			msg += ","
		}
		if i == count-1 && count > 1 {
			msg += " or"
		}
		msg += fmt.Sprintf(" %s", m)
	}

	msg += fmt.Sprintf(" but is positioned on a %s", te.current)
	if te.parent != mode(0) {
		msg += fmt.Sprintf(" with parent %s", te.parent)
	}
	return msg
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"crypto/rand"
"encoding"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"sync/atomic"
"time"
)
// ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID.
var ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID")
// ObjectID is the BSON ObjectID type.
//
// As produced by NewObjectIDFromTimestamp, bytes 0-3 hold the big-endian Unix
// timestamp in seconds, bytes 4-8 hold the per-process random value and bytes
// 9-11 hold the low 24 bits of an incrementing counter.
type ObjectID [12]byte
// NilObjectID is the zero value for ObjectID.
var NilObjectID ObjectID
// objectIDCounter is the counter that NewObjectIDFromTimestamp increments atomically for
// every generated ObjectID. It is seeded with a random value at package initialization.
var objectIDCounter = readRandomUint32()
// processUnique holds the 5 random bytes, read once at package initialization, that are
// copied into every generated ObjectID.
var processUnique = processUniqueBytes()
// Compile-time checks that ObjectID implements the text marshaling interfaces.
var _ encoding.TextMarshaler = ObjectID{}
var _ encoding.TextUnmarshaler = &ObjectID{}
// NewObjectID generates a new ObjectID stamped with the current time.
func NewObjectID() ObjectID {
	now := time.Now()
	return NewObjectIDFromTimestamp(now)
}
// NewObjectIDFromTimestamp generates a new ObjectID based on the given time.
//
// The result consists of the big-endian Unix timestamp in seconds (truncated
// to 32 bits), the 5 per-process random bytes, and the low 24 bits of the
// atomically incremented package counter.
func NewObjectIDFromTimestamp(timestamp time.Time) ObjectID {
	var id ObjectID

	seconds := uint32(timestamp.Unix())
	binary.BigEndian.PutUint32(id[:4], seconds)

	copy(id[4:9], processUnique[:])

	next := atomic.AddUint32(&objectIDCounter, 1)
	putUint24(id[9:], next)

	return id
}
// Timestamp extracts the time part of the ObjectID: the first four bytes read
// as a big-endian count of seconds since the Unix epoch, returned in UTC.
func (id ObjectID) Timestamp() time.Time {
	seconds := int64(binary.BigEndian.Uint32(id[:4]))
	return time.Unix(seconds, 0).UTC()
}
// Hex returns the hex encoding of the ObjectID as a 24-character lowercase
// string.
func (id ObjectID) Hex() string {
	return hex.EncodeToString(id[:])
}
// String returns the ObjectID in the form ObjectID("<hex>").
func (id ObjectID) String() string {
	const prefix, suffix = `ObjectID("`, `")`
	return prefix + id.Hex() + suffix
}
// IsZero reports whether id equals NilObjectID, the empty ObjectID.
func (id ObjectID) IsZero() bool {
	return NilObjectID == id
}
// ObjectIDFromHex creates a new ObjectID from a hex string.
//
// It returns ErrInvalidHex when s is not exactly 24 characters long, and the
// error reported by hex.Decode when s contains non-hexadecimal characters. In
// both cases the returned ObjectID is NilObjectID.
func ObjectIDFromHex(s string) (ObjectID, error) {
	var oid ObjectID
	if len(s) != hex.EncodedLen(len(oid)) {
		return NilObjectID, ErrInvalidHex
	}
	if _, err := hex.Decode(oid[:], []byte(s)); err != nil {
		return NilObjectID, err
	}
	return oid, nil
}
// MarshalText returns the ObjectID as UTF-8-encoded text: its 24-character hex
// form. Implementing this allows ObjectID to be used as a map key when
// marshalling JSON. See https://pkg.go.dev/encoding#TextMarshaler
func (id ObjectID) MarshalText() ([]byte, error) {
	text := make([]byte, hex.EncodedLen(len(id)))
	hex.Encode(text, id[:])
	return text, nil
}
// UnmarshalText decodes the textual form of an ObjectID into id.
//
// b must be either empty or a 24-character hex string:
//   - an empty input is treated as a special value and sets id to NilObjectID;
//   - a 24-character hex string is decoded with ObjectIDFromHex.
//
// For any other input an error is returned and id is left unchanged.
//
// Implementing this allows us to use ObjectID as a map key when unmarshalling
// JSON. See https://pkg.go.dev/encoding#TextUnmarshaler
func (id *ObjectID) UnmarshalText(b []byte) error {
	// NB(charlie): The json package will use UnmarshalText instead of
	// UnmarshalJSON if the value is a string.

	// An empty string is not a valid ObjectID, but we treat it as a
	// special value that decodes as NilObjectID. Assign explicitly so that a
	// previously populated id does not keep its stale value.
	if len(b) == 0 {
		*id = NilObjectID
		return nil
	}

	oid, err := ObjectIDFromHex(string(b))
	if err != nil {
		return err
	}
	*id = oid
	return nil
}
// MarshalJSON returns the ObjectID as a JSON string: the 24-character hex form
// wrapped in double quotes.
func (id ObjectID) MarshalJSON() ([]byte, error) {
	const quote = '"'
	out := make([]byte, 2+hex.EncodedLen(len(id)))
	out[0] = quote
	hex.Encode(out[1:len(out)-1], id[:])
	out[len(out)-1] = quote
	return out, nil
}
// UnmarshalJSON decodes an ObjectID from JSON. The input is interpreted as
// follows, in this order:
//
//   - the literal null leaves id unchanged;
//   - a JSON string is handed to UnmarshalText without its surrounding quotes;
//   - an input that is exactly 12 bytes long is copied verbatim as the raw
//     ObjectID bytes;
//   - otherwise the input must be a JSON object with key "$oid" that stores a
//     hex encoded ObjectID, for example
//     {"$oid": "65b3f7edd9bfca00daa6e3b3"}.
//
// For any other inputs, an error will be returned.
func (id *ObjectID) UnmarshalJSON(b []byte) error {
	switch {
	case string(b) == "null":
		// Ignore "null" to keep parity with the standard library. Decoding a
		// JSON null into a non-pointer ObjectID field will leave the field
		// unchanged. For pointer values, encoding/json will set the pointer
		// to nil and will not enter the UnmarshalJSON hook.
		return nil
	case len(b) >= 2 && b[0] == '"':
		// JSON string: strip the first and last byte and decode the text.
		return id.UnmarshalText(b[1 : len(b)-1])
	case len(b) == 12:
		copy(id[:], b)
		return nil
	}

	// Extended JSON form: {"$oid": "<hex>"}.
	var wrapper struct {
		OID *string `json:"$oid"`
	}
	if err := json.Unmarshal(b, &wrapper); err != nil {
		return fmt.Errorf("failed to parse extended JSON ObjectID: %w", err)
	}
	if wrapper.OID == nil {
		return errors.New("not an extended JSON ObjectID")
	}

	parsed, err := ObjectIDFromHex(*wrapper.OID)
	if err != nil {
		return err
	}
	*id = parsed
	return nil
}
// processUniqueBytes returns 5 bytes read from crypto/rand. It panics if the
// random source cannot supply them.
func processUniqueBytes() [5]byte {
	var unique [5]byte
	if _, err := io.ReadFull(rand.Reader, unique[:]); err != nil {
		panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
	}
	return unique
}
// readRandomUint32 returns a uint32 assembled from 4 bytes read from
// crypto/rand, with the first byte as the least significant. It panics if the
// random source cannot supply them.
func readRandomUint32() uint32 {
	var raw [4]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
	}

	var v uint32
	for i, octet := range raw {
		v |= uint32(octet) << (8 * uint(i))
	}
	return v
}
// putUint24 writes the low 24 bits of v into b[0:3] in big-endian order. b must
// have at least 3 bytes.
func putUint24(b []byte, v uint32) {
	b[0], b[1], b[2] = byte(v>>16), byte(v>>8), byte(v)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// Compile-time checks that *pointerCodec implements both codec interfaces.
var (
_ ValueEncoder = &pointerCodec{}
_ ValueDecoder = &pointerCodec{}
)
// pointerCodec is the Codec used for pointers.
type pointerCodec struct {
// ecache caches, per pointer type, the encoder looked up for the pointed-to type.
ecache typeEncoderCache
// dcache caches, per pointer type, the decoder looked up for the pointed-to type.
dcache typeDecoderCache
}
// EncodeValue handles encoding a pointer by either encoding it to BSON Null if
// the pointer is nil or looking up an encoder for the type of value the pointer
// points to.
//
// An invalid val is also written as BSON Null; any other non-pointer val
// yields a ValueEncoderError. Looked-up encoders are cached per pointer type.
func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	switch {
	case !val.IsValid():
		return vw.WriteNull()
	case val.Kind() != reflect.Ptr:
		return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
	case val.IsNil():
		return vw.WriteNull()
	}

	ptrType := val.Type()
	if cached, found := pc.ecache.Load(ptrType); found {
		// A cached nil entry records that an earlier lookup found no encoder.
		if cached == nil {
			return errNoEncoder{Type: ptrType}
		}
		return cached.EncodeValue(ec, vw, val.Elem())
	}

	// TODO(charlie): handle concurrent requests for the same type
	elemEnc, lookupErr := ec.LookupEncoder(ptrType.Elem())
	// The lookup result is stored before the error is checked.
	elemEnc = pc.ecache.LoadOrStore(ptrType, elemEnc)
	if lookupErr != nil {
		return lookupErr
	}
	return elemEnc.EncodeValue(ec, vw, val.Elem())
}
// DecodeValue handles decoding a pointer by looking up a decoder for the type
// it points to and using that to decode. If the BSON value is Null or
// Undefined, this method will set the pointer to nil.
//
// val must be a settable pointer; otherwise a ValueDecoderError is returned. A
// nil pointer is allocated before decoding. Looked-up decoders are cached per
// pointer type.
func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Kind() != reflect.Ptr {
		return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
	}

	ptrType := val.Type()
	switch vr.Type() {
	case TypeNull:
		val.Set(reflect.Zero(ptrType))
		return vr.ReadNull()
	case TypeUndefined:
		val.Set(reflect.Zero(ptrType))
		return vr.ReadUndefined()
	}

	if val.IsNil() {
		val.Set(reflect.New(ptrType.Elem()))
	}

	if cached, found := pc.dcache.Load(ptrType); found {
		// A cached nil entry records that an earlier lookup found no decoder.
		if cached == nil {
			return errNoDecoder{Type: ptrType}
		}
		return cached.DecodeValue(dc, vr, val.Elem())
	}

	// TODO(charlie): handle concurrent requests for the same type
	elemDec, lookupErr := dc.LookupDecoder(ptrType.Elem())
	// The lookup result is stored before the error is checked.
	elemDec = pc.dcache.LoadOrStore(ptrType, elemDec)
	if lookupErr != nil {
		return lookupErr
	}
	return elemDec.DecodeValue(dc, vr, val.Elem())
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"time"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
// IsZero reports whether the value should be treated as its zero value.
IsZero() bool
}
// The following primitive types are similar to Go primitives for BSON types that
// do not have direct Go primitive representations.
// Binary represents a BSON binary value: a subtype byte plus the data bytes.
type Binary struct {
	Subtype byte
	Data    []byte
}

// Equal compares bp to bp2 and returns true if they have the same subtype and
// byte-for-byte identical data. Nil and empty data compare as equal.
func (bp Binary) Equal(bp2 Binary) bool {
	return bp.Subtype == bp2.Subtype && bytes.Equal(bp.Data, bp2.Data)
}

// IsZero reports whether bp is the empty Binary: subtype 0 with no data.
func (bp Binary) IsZero() bool {
	if bp.Subtype != 0 {
		return false
	}
	return len(bp.Data) == 0
}
// Undefined represents the BSON undefined value type. It carries no data.
type Undefined struct{}
// DateTime represents the BSON datetime value, stored as the number of
// milliseconds since the Unix epoch.
type DateTime int64

// Compile-time checks that DateTime implements the JSON interfaces.
var (
	_ json.Marshaler   = DateTime(0)
	_ json.Unmarshaler = (*DateTime)(nil)
)

// MarshalJSON encodes d the same way time.Time does, using its UTC time.
func (d DateTime) MarshalJSON() ([]byte, error) {
	utc := d.Time().UTC()
	return utc.MarshalJSON()
}

// UnmarshalJSON creates a bson.DateTime from a JSON string in the format
// accepted by time.Time.UnmarshalJSON. A JSON null leaves d unchanged.
func (d *DateTime) UnmarshalJSON(data []byte) error {
	// Ignore "null" so that we can distinguish between a "null" value and
	// valid value that is the zero time (as reported by time.Time.IsZero).
	if string(data) == "null" {
		return nil
	}

	var parsed time.Time
	err := parsed.UnmarshalJSON(data)
	if err != nil {
		return err
	}
	*d = NewDateTimeFromTime(parsed)
	return nil
}

// Time returns the date as a time type.
func (d DateTime) Time() time.Time {
	return time.UnixMilli(int64(d))
}

// NewDateTimeFromTime creates a new DateTime from a Time, truncating it to
// millisecond precision.
func NewDateTimeFromTime(t time.Time) DateTime {
	return DateTime(t.UnixMilli())
}
// Null represents the BSON null value. It carries no data.
type Null struct{}
// Regex represents a BSON regex value: a pattern string and an options string.
type Regex struct {
	Pattern string
	Options string
}

// String returns rp rendered as {"pattern": "<Pattern>", "options": "<Options>"}.
// The field values are inserted verbatim, without escaping.
func (rp Regex) String() string {
	const layout = `{"pattern": "%s", "options": "%s"}`
	return fmt.Sprintf(layout, rp.Pattern, rp.Options)
}

// Equal compares rp to rp2 and returns true if they are equal.
func (rp Regex) Equal(rp2 Regex) bool {
	return rp == rp2
}

// IsZero reports whether rp is the empty Regex (no pattern and no options).
func (rp Regex) IsZero() bool {
	return rp == Regex{}
}
// DBPointer represents a BSON dbpointer value: a DB string and an ObjectID.
type DBPointer struct {
	DB      string
	Pointer ObjectID
}

// String returns d rendered as {"db": "<DB>", "pointer": "<Pointer>"}, with
// Pointer formatted via its String method.
func (d DBPointer) String() string {
	const layout = `{"db": "%s", "pointer": "%s"}`
	return fmt.Sprintf(layout, d.DB, d.Pointer)
}

// Equal compares d to d2 and returns true if they are equal.
func (d DBPointer) Equal(d2 DBPointer) bool {
	return d2 == d
}

// IsZero reports whether d is the empty DBPointer: an empty DB and a zero
// Pointer.
func (d DBPointer) IsZero() bool {
	if d.DB != "" {
		return false
	}
	return d.Pointer.IsZero()
}
// JavaScript represents a BSON JavaScript code value.
type JavaScript string

// Symbol represents a BSON symbol value.
type Symbol string

// CodeWithScope represents a BSON JavaScript code with scope value.
type CodeWithScope struct {
	Code  JavaScript
	Scope interface{}
}

// String returns cws rendered as {"code": "<Code>", "scope": <Scope>}, with
// Scope formatted using the %v verb.
func (cws CodeWithScope) String() string {
	const layout = `{"code": "%s", "scope": %v}`
	return fmt.Sprintf(layout, cws.Code, cws.Scope)
}
// Timestamp represents a BSON timestamp value. Timestamps are ordered by T
// first and by I second.
type Timestamp struct {
	T uint32
	I uint32
}

// After reports whether the time instant tp is after tp2.
func (tp Timestamp) After(tp2 Timestamp) bool {
	if tp.T != tp2.T {
		return tp.T > tp2.T
	}
	return tp.I > tp2.I
}

// Before reports whether the time instant tp is before tp2.
func (tp Timestamp) Before(tp2 Timestamp) bool {
	if tp.T != tp2.T {
		return tp.T < tp2.T
	}
	return tp.I < tp2.I
}

// Equal compares tp to tp2 and returns true if they are equal.
func (tp Timestamp) Equal(tp2 Timestamp) bool {
	return tp == tp2
}

// IsZero reports whether tp is the zero Timestamp.
func (tp Timestamp) IsZero() bool {
	return tp == Timestamp{}
}

// Compare compares the time instant tp with tp2. If tp is before tp2, it
// returns -1; if tp is after tp2, it returns +1; if they're the same, it
// returns 0.
func (tp Timestamp) Compare(tp2 Timestamp) int {
	if tp.Equal(tp2) {
		return 0
	}
	if tp.Before(tp2) {
		return -1
	}
	return +1
}
// MinKey represents the BSON minkey value. It carries no data.
type MinKey struct{}
// MaxKey represents the BSON maxkey value. It carries no data.
type MaxKey struct{}
// D is an ordered representation of a BSON document. This type should be used
// when the order of the elements matters, such as MongoDB command documents. If
// the order of the elements does not matter, an M should be used instead.
//
// Example usage:
//
//	bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
type D []E

// String returns the canonical extended JSON encoding of d (without HTML
// escaping), or the empty string if d cannot be marshaled.
func (d D) String() string {
	encoded, err := MarshalExtJSON(d, true, false)
	if err == nil {
		return string(encoded)
	}
	return ""
}
// MarshalJSON encodes D into a JSON object whose members appear in the same
// order as the elements of d. A nil D encodes as JSON null.
func (d D) MarshalJSON() ([]byte, error) {
	if d == nil {
		return json.Marshal(nil)
	}

	var out bytes.Buffer
	enc := json.NewEncoder(&out)

	out.WriteByte('{')
	for idx, elem := range d {
		// Separate members with a comma placed before every member but the
		// first.
		if idx > 0 {
			out.WriteByte(',')
		}
		if err := enc.Encode(elem.Key); err != nil {
			return nil, err
		}
		out.WriteByte(':')
		if err := enc.Encode(elem.Value); err != nil {
			return nil, err
		}
	}
	out.WriteByte('}')

	return json.RawMessage(out.Bytes()).MarshalJSON()
}
// UnmarshalJSON decodes D from JSON, preserving the order of the object's
// members. A JSON null sets d to nil. Any input that does not start a JSON
// object yields a *json.UnmarshalTypeError.
func (d *D) UnmarshalJSON(b []byte) error {
	dec := json.NewDecoder(bytes.NewReader(b))

	first, err := dec.Token()
	if err != nil {
		return err
	}
	// A nil token is the JSON null literal.
	if first == nil {
		*d = nil
		return nil
	}

	delim, isDelim := first.(json.Delim)
	if !isDelim || delim != '{' {
		return &json.UnmarshalTypeError{
			Value:  tokenString(first),
			Type:   reflect.TypeOf(D(nil)),
			Offset: dec.InputOffset(),
		}
	}

	decoded, err := jsonDecodeD(dec)
	*d = decoded
	return err
}
// E represents a BSON element for a D. It is usually used inside a D.
type E struct {
// Key is the element's name.
Key string
// Value is the element's value.
Value interface{}
}
// M is an unordered representation of a BSON document. This type should be used
// when the order of the elements does not matter. This type is handled as a
// regular map[string]interface{} when encoding and decoding. Elements will be
// serialized in an undefined, random order. If the order of the elements
// matters, a D should be used instead.
//
// Example usage:
//
//	bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
type M map[string]interface{}

// String returns the canonical extended JSON encoding of m (without HTML
// escaping), or the empty string if m cannot be marshaled.
func (m M) String() string {
	encoded, err := MarshalExtJSON(m, true, false)
	if err == nil {
		return string(encoded)
	}
	return ""
}
// An A is an ordered representation of a BSON array. Its elements may be of
// any type, including nested documents and arrays.
//
// Example usage:
//
//	bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
type A []interface{}
// jsonDecodeD reads the members of a JSON object from dec into a D, keeping
// their order. The opening '{' must already have been consumed; reading stops
// at the first token that is not a string key.
//
// Nested objects are decoded as D and nested arrays as []interface{}. Every
// other value is stored as the token returned by the decoder. Tokenizer errors
// are returned with a nil D.
func jsonDecodeD(dec *json.Decoder) (D, error) {
	doc := D{}
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, err
		}
		name, isKey := tok.(string)
		if !isKey {
			// Not a member name: the object has ended.
			return doc, nil
		}

		tok, err = dec.Token()
		if err != nil {
			return nil, err
		}

		elem := E{Key: name}
		if delim, isDelim := tok.(json.Delim); !isDelim {
			elem.Value = tok
		} else if delim == '[' {
			arr, err := jsonDecodeSlice(dec)
			if err != nil {
				return nil, err
			}
			elem.Value = arr
		} else if delim == '{' {
			sub, err := jsonDecodeD(dec)
			if err != nil {
				return nil, err
			}
			elem.Value = sub
		}
		doc = append(doc, elem)
	}
}
// jsonDecodeSlice reads the elements of a JSON array from dec. The opening '['
// must already have been consumed; reading stops at the first delimiter that
// does not open a nested array or object.
//
// Nested arrays are decoded as []interface{} and nested objects as D. Every
// other element is stored as the token returned by the decoder. An array
// without elements yields a nil slice. Tokenizer errors are returned with a
// nil slice.
func jsonDecodeSlice(dec *json.Decoder) ([]interface{}, error) {
	var items []interface{}
	for {
		tok, err := dec.Token()
		if err != nil {
			return nil, err
		}

		delim, isDelim := tok.(json.Delim)
		if !isDelim {
			items = append(items, tok)
			continue
		}

		switch delim {
		case '[':
			nested, err := jsonDecodeSlice(dec)
			if err != nil {
				return nil, err
			}
			items = append(items, nested)
		case '{':
			doc, err := jsonDecodeD(dec)
			if err != nil {
				return nil, err
			}
			items = append(items, doc)
		default:
			// Any other delimiter ends this array.
			return items, nil
		}
	}
}
// tokenString returns a short description of the kind of JSON value that the
// token t represents, for use in error values such as json.UnmarshalTypeError:
//
//   - "object" or "array" for the opening delimiters '{' and '[';
//   - "bool" for booleans;
//   - "number" for numbers, whether delivered as float64 or as json.Number;
//   - "string" for strings;
//   - "unknown" for everything else, including closing delimiters and nil.
func tokenString(t json.Token) string {
	switch v := t.(type) {
	case json.Delim:
		switch v {
		case '{':
			return "object"
		case '[':
			return "array"
		}
	case bool:
		return "bool"
	case float64, json.Number:
		// json.Number is how a decoder configured with UseNumber delivers
		// numeric tokens, so it describes a number, not a string.
		return "number"
	case string:
		return "string"
	}
	return "unknown"
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
)
// tRawValue is the reflect.Type of RawValue; the RawValue codecs compare
// incoming values against it.
var tRawValue = reflect.TypeOf(RawValue{})
// tRaw is the reflect.Type of Raw; the Raw codecs compare incoming values
// against it.
var tRaw = reflect.TypeOf(Raw(nil))
// registerPrimitiveCodecs registers the type encoders and decoders for
// RawValue and Raw on reg. reg must be non-nil.
func registerPrimitiveCodecs(reg *Registry) {
	// Encoders.
	reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue))
	reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue))

	// Decoders.
	reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue))
	reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue))
}
// rawValueEncodeValue is the ValueEncoderFunc for RawValue.
//
// It returns a ValueEncoderError if val is invalid or is not a RawValue, and
// an error if the RawValue's Type is not a valid BSON type. Otherwise the raw
// bytes are copied to vw via copyValueFromBytes.
func rawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tRawValue {
		return ValueEncoderError{
			Name:     "RawValueEncodeValue",
			Types:    []reflect.Type{tRawValue},
			Received: val,
		}
	}

	rv := val.Interface().(RawValue)
	if rv.Type.IsValid() {
		return copyValueFromBytes(vw, rv.Type, rv.Value)
	}
	return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rv.Type))
}
// rawValueDecodeValue is the ValueDecoderFunc for RawValue. It returns a
// ValueDecoderError if val is not a settable RawValue. Otherwise it reads the
// next value from vr with copyValueToBytes and, on success, stores the
// resulting type and bytes in val; on failure val is left untouched.
func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tRawValue {
		return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val}
	}

	bsonType, data, err := copyValueToBytes(vr)
	if err == nil {
		val.Set(reflect.ValueOf(RawValue{Type: bsonType, Value: data}))
	}
	return err
}
// rawEncodeValue is the ValueEncoderFunc for Raw. It returns a
// ValueEncoderError if val is invalid or is not a Raw; otherwise the document
// bytes are copied to vw via copyDocumentFromBytes.
func rawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tRaw {
		return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val}
	}

	doc := val.Interface().(Raw)
	return copyDocumentFromBytes(vw, doc)
}
// rawDecodeValue is the ValueDecoderFunc for Raw. It returns a
// ValueDecoderError if val is not a settable Raw. Otherwise val is truncated to
// length zero (allocating an empty slice first if it is nil) and the document
// read from vr is appended to it with appendDocumentBytes. val is updated with
// whatever appendDocumentBytes returns, even when it also reports an error.
func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tRaw {
		return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val}
	}

	if val.IsNil() {
		// val's type was verified to be tRaw above.
		val.Set(reflect.MakeSlice(tRaw, 0, 0))
	}
	val.SetLen(0)

	dst := val.Interface().(Raw)
	doc, err := appendDocumentBytes(dst, vr)
	val.Set(reflect.ValueOf(doc))
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// ErrNilReader indicates that an operation was attempted on a nil reader.
// NOTE(review): no use of this error is visible in this part of the file;
// confirm which operations return it.
var ErrNilReader = errors.New("nil reader")
// Raw is a raw encoded BSON document. It can be used to delay BSON document decoding or precompute
// a BSON encoded document. Its methods operate on the bytes by converting them to
// bsoncore.Document.
//
// A Raw must be a full BSON document. Use the RawValue type for individual BSON values.
type Raw []byte
// ReadDocument reads a BSON document from r using
// bsoncore.NewDocumentFromReader and returns it as a Raw, together with any
// error reported by that call. If the reader contains multiple BSON documents,
// only the first document is read.
func ReadDocument(r io.Reader) (Raw, error) {
	core, err := bsoncore.NewDocumentFromReader(r)
	return Raw(core), err
}
// Validate validates the document by delegating to bsoncore.Document.Validate.
// Only the first document in the slice is validated; to validate further
// documents, the slice must be resliced.
func (r Raw) Validate() (err error) {
	err = bsoncore.Document(r).Validate()
	return err
}
// Lookup searches the document, potentially recursively, for the given key. If
// multiple keys are provided, this method recurses down, as long as the top
// and intermediate nodes are either documents or arrays. If an error occurs or
// if the value doesn't exist, an empty RawValue is returned.
func (r Raw) Lookup(key ...string) RawValue {
	found := bsoncore.Document(r).Lookup(key...)
	return convertFromCoreValue(found)
}
// LookupErr searches the document, and potentially subdocuments or arrays, for
// the provided key. Each key passed to this method represents one layer of
// depth. The converted value is returned together with the error reported by
// bsoncore.Document.LookupErr.
func (r Raw) LookupErr(key ...string) (RawValue, error) {
	found, err := bsoncore.Document(r).LookupErr(key...)
	return convertFromCoreValue(found), err
}
// Elements returns this document as a slice of elements. An empty Raw yields
// (nil, nil). If bsoncore.Document.Elements reports an error, that error is
// returned with a nil slice; no partial result is returned.
func (r Raw) Elements() ([]RawElement, error) {
	if len(r) == 0 {
		return nil, nil
	}

	coreElems, err := bsoncore.Document(r).Elements()
	if err != nil {
		return nil, err
	}

	out := make([]RawElement, len(coreElems))
	for i := range coreElems {
		out[i] = RawElement(coreElems[i])
	}
	return out, nil
}
// Values returns this document as a slice of values. Every value returned by
// bsoncore.Document.Values is converted and returned together with that call's
// error, so any values bsoncore produced before hitting an invalid point are
// still returned alongside the error.
func (r Raw) Values() ([]RawValue, error) {
	coreVals, err := bsoncore.Document(r).Values()

	out := make([]RawValue, len(coreVals))
	for i, cv := range coreVals {
		out[i] = convertFromCoreValue(cv)
	}
	return out, err
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (r Raw) Index(index uint) RawElement {
	elem := bsoncore.Document(r).Index(index)
	return RawElement(elem)
}
// IndexErr searches for and retrieves the element at the given index,
// returning the error reported by bsoncore.Document.IndexErr instead of
// panicking.
func (r Raw) IndexErr(index uint) (RawElement, error) {
	found, err := bsoncore.Document(r).IndexErr(index)
	return RawElement(found), err
}
// String returns the BSON document encoded as Extended JSON, as produced by
// bsoncore.Document.String.
func (r Raw) String() string {
	core := bsoncore.Document(r)
	return core.String()
}
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"io"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// RawArray is a raw bytes representation of a BSON array. Its methods operate
// on the bytes by converting them to bsoncore.Array.
type RawArray []byte
// ReadArray reads a BSON array from r using bsoncore.NewArrayFromReader and
// returns it as a bson.RawArray, together with any error reported by that call.
func ReadArray(r io.Reader) (RawArray, error) {
	arr, err := bsoncore.NewArrayFromReader(r)
	return RawArray(arr), err
}
// Index searches for and retrieves the value at the given index. This method
// will panic if the array is invalid or if the index is out of bounds.
func (a RawArray) Index(index uint) RawValue {
	found := bsoncore.Array(a).Index(index)
	return convertFromCoreValue(found)
}
// IndexErr searches for and retrieves the value at the given index, returning
// the error reported by bsoncore.Array.IndexErr instead of panicking.
func (a RawArray) IndexErr(index uint) (RawValue, error) {
	found, err := bsoncore.Array(a).IndexErr(index)
	return convertFromCoreValue(found), err
}
// DebugString outputs a human readable version of the array. It will attempt
// to stringify the valid components of the array even if the entire array is
// not valid.
func (a RawArray) DebugString() string {
	core := bsoncore.Array(a)
	return core.DebugString()
}
// String outputs an Extended JSON version of the array. If the array is not
// valid, this method returns an empty string.
func (a RawArray) String() string {
	core := bsoncore.Array(a)
	return core.String()
}
// Values returns this array as a slice of values. If bsoncore.Array.Values
// reports an error, that error is returned with a nil slice; no partial result
// is returned.
func (a RawArray) Values() ([]RawValue, error) {
	coreVals, err := bsoncore.Array(a).Values()
	if err != nil {
		return nil, err
	}

	out := make([]RawValue, len(coreVals))
	for i, cv := range coreVals {
		out[i] = convertFromCoreValue(cv)
	}
	return out, nil
}
// Validate validates the array and ensures the elements contained within are
// valid, by delegating to bsoncore.Array.Validate.
func (a RawArray) Validate() error {
	core := bsoncore.Array(a)
	return core.Validate()
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// RawElement is a raw encoded BSON document or array element. Its methods
// operate on the bytes by converting them to bsoncore.Element.
type RawElement []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing whether the element is valid is important, use KeyErr.
func (re RawElement) Key() string {
	core := bsoncore.Element(re)
	return core.Key()
}
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (re RawElement) KeyErr() (string, error) {
	core := bsoncore.Element(re)
	return core.KeyErr()
}
// Value returns the value of this element. If the element is not valid, this method returns an
// empty RawValue. If knowing whether the element is valid is important, use ValueErr.
func (re RawElement) Value() RawValue {
	coreVal := bsoncore.Element(re).Value()
	return convertFromCoreValue(coreVal)
}
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (re RawElement) ValueErr() (RawValue, error) {
	coreVal, err := bsoncore.Element(re).ValueErr()
	return convertFromCoreValue(coreVal), err
}
// Validate ensures re is a valid BSON element by delegating to
// bsoncore.Element.Validate.
func (re RawElement) Validate() error {
	core := bsoncore.Element(re)
	return core.Validate()
}
// String returns the BSON element encoded as Extended JSON. The element is
// wrapped in a single-element document with bsoncore.BuildDocument and
// rendered by MarshalExtJSON(doc, true, false). If marshaling fails,
// "<malformed>" is returned.
func (re RawElement) String() string {
	wrapped := Raw(bsoncore.BuildDocument(nil, re))
	if out, err := MarshalExtJSON(wrapped, true, false); err == nil {
		return string(out)
	}
	return "<malformed>"
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (re RawElement) DebugString() string {
	core := bsoncore.Element(re)
	return core.DebugString()
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// ErrNilContext is returned when the provided DecodeContext is nil, for
// example by RawValue.UnmarshalWithContext.
var ErrNilContext = errors.New("DecodeContext cannot be nil")
// ErrNilRegistry is returned when the provided registry is nil, for example by
// RawValue.UnmarshalWithRegistry.
var ErrNilRegistry = errors.New("Registry cannot be nil")
// RawValue is a raw encoded BSON value. It can be used to delay BSON value decoding or precompute
// BSON encoded value. Type is the BSON type of the value and Value is the raw encoded BSON value.
//
// A RawValue must be an individual BSON value. Use the Raw type for full BSON documents.
type RawValue struct {
	// Type is the BSON type of the encoded value.
	Type Type
	// Value holds the raw encoded bytes of the value.
	Value []byte
	// r is the registry Unmarshal uses when it is non-nil; otherwise Unmarshal
	// falls back to defaultRegistry.
	r *Registry
}
// IsZero reports whether the RawValue is zero, i.e. no data is present on
// the RawValue. It returns true if Type is 0 and Value is empty or nil.
func (rv RawValue) IsZero() bool {
	if rv.Type != 0x00 {
		return false
	}
	return len(rv.Value) == 0
}
// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an
// error is returned. The registry attached to the RawValue is used when there is one; otherwise
// the default registry is used. Users wishing to specify the registry should use
// UnmarshalWithRegistry.
func (rv RawValue) Unmarshal(val interface{}) error {
	if rv.r != nil {
		return rv.UnmarshalWithRegistry(rv.r, val)
	}
	return rv.UnmarshalWithRegistry(defaultRegistry, val)
}
// Equal reports whether rv and rv2 have the same Type and byte-wise equal
// Value. The attached registry is not compared.
func (rv RawValue) Equal(rv2 RawValue) bool {
	return rv.Type == rv2.Type && bytes.Equal(rv.Value, rv2.Value)
}
// UnmarshalWithRegistry performs the same unmarshalling as Unmarshal but uses the provided registry
// instead of the one attached or the default registry.
//
// It returns ErrNilRegistry if r is nil, and an error if val is not a non-nil
// pointer. Otherwise the decoder for the pointed-to type is looked up in r and
// used to decode the raw value into it.
func (rv RawValue) UnmarshalWithRegistry(r *Registry, val interface{}) error {
	if r == nil {
		return ErrNilRegistry
	}

	vr := newValueReader(rv.Type, bytes.NewReader(rv.Value))
	rval := reflect.ValueOf(val)
	if rval.Kind() != reflect.Ptr {
		return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
	}
	// A typed nil pointer passes the Kind check, but Elem() would yield the
	// zero Value and rval.Type() below would panic.
	if rval.IsNil() {
		return fmt.Errorf("argument to Unmarshal* must be a non-nil pointer, but got a nil %s", rval.Type())
	}
	rval = rval.Elem()

	dec, err := r.LookupDecoder(rval.Type())
	if err != nil {
		return err
	}
	return dec.DecodeValue(DecodeContext{Registry: r}, vr, rval)
}
// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext
// instead of the one attached or the default registry.
//
// It returns ErrNilContext if dc is nil, and an error if val is not a non-nil
// pointer. Otherwise the decoder for the pointed-to type is looked up through
// dc and used to decode the raw value into it.
func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) error {
	if dc == nil {
		return ErrNilContext
	}

	vr := newValueReader(rv.Type, bytes.NewReader(rv.Value))
	rval := reflect.ValueOf(val)
	if rval.Kind() != reflect.Ptr {
		return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
	}
	// A typed nil pointer passes the Kind check, but Elem() would yield the
	// zero Value and rval.Type() below would panic.
	if rval.IsNil() {
		return fmt.Errorf("argument to Unmarshal* must be a non-nil pointer, but got a nil %s", rval.Type())
	}
	rval = rval.Elem()

	dec, err := dc.LookupDecoder(rval.Type())
	if err != nil {
		return err
	}
	return dec.DecodeValue(*dc, vr, rval)
}
// convertFromCoreValue converts a bsoncore.Value into a RawValue carrying the
// same type and data bytes. The bytes are shared, not copied, and no registry
// is attached.
func convertFromCoreValue(v bsoncore.Value) RawValue {
	var rv RawValue
	rv.Type = Type(v.Type)
	rv.Value = v.Data
	return rv
}
// convertToCoreValue converts a RawValue into a bsoncore.Value carrying the
// same type and data bytes. The bytes are shared, not copied.
func convertToCoreValue(v RawValue) bsoncore.Value {
	var core bsoncore.Value
	core.Type = bsoncore.Type(v.Type)
	core.Data = v.Value
	return core
}
// Validate ensures the value is a valid BSON value by delegating to
// bsoncore.Value.Validate.
func (rv RawValue) Validate() error {
	core := convertToCoreValue(rv)
	return core.Validate()
}
// IsNumber returns true if the type of rv is a numeric BSON type.
func (rv RawValue) IsNumber() bool {
	core := convertToCoreValue(rv)
	return core.IsNumber()
}
// String implements the fmt.Stringer interface. This method will return values in extended JSON
// format. If the value is not valid, this returns an empty string.
func (rv RawValue) String() string {
	core := convertToCoreValue(rv)
	return core.String()
}
// DebugString outputs a human readable version of the value. It will attempt to stringify the
// valid components of the value even if the entire value is not valid.
func (rv RawValue) DebugString() string {
	core := convertToCoreValue(rv)
	return core.DebugString()
}
// Double returns the float64 value for this RawValue.
// It panics if rv's BSON type is not bson.TypeDouble.
func (rv RawValue) Double() float64 {
	core := convertToCoreValue(rv)
	return core.Double()
}
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (rv RawValue) DoubleOK() (float64, bool) {
	core := convertToCoreValue(rv)
	return core.DoubleOK()
}
// StringValue returns the string value for this RawValue.
// It panics if rv's BSON type is not bson.TypeString.
//
// NOTE: This method is called StringValue to avoid a collision with the String method which
// implements the fmt.Stringer interface.
func (rv RawValue) StringValue() string {
	core := convertToCoreValue(rv)
	return core.StringValue()
}
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (rv RawValue) StringValueOK() (string, bool) {
	core := convertToCoreValue(rv)
	return core.StringValueOK()
}
// Document returns the BSON document the value represents as a Raw. It panics
// if the value is a BSON type other than document.
func (rv RawValue) Document() Raw {
	doc := convertToCoreValue(rv).Document()
	return Raw(doc)
}
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (rv RawValue) DocumentOK() (Raw, bool) {
	core := convertToCoreValue(rv)
	doc, ok := core.DocumentOK()
	return Raw(doc), ok
}
// Array returns the BSON array the value represents as a RawArray. It panics
// if the value is a BSON type other than array.
func (rv RawValue) Array() RawArray {
	arr := convertToCoreValue(rv).Array()
	return RawArray(arr)
}
// ArrayOK is the same as Array, except it returns a boolean instead
// of panicking.
func (rv RawValue) ArrayOK() (RawArray, bool) {
	core := convertToCoreValue(rv)
	arr, ok := core.ArrayOK()
	return RawArray(arr), ok
}
// Binary returns the BSON binary value the value represents, as its subtype
// byte and data bytes. It panics if the value is a BSON type other than binary.
func (rv RawValue) Binary() (subtype byte, data []byte) {
	core := convertToCoreValue(rv)
	return core.Binary()
}
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (rv RawValue) BinaryOK() (subtype byte, data []byte, ok bool) {
	core := convertToCoreValue(rv)
	return core.BinaryOK()
}
// ObjectID returns the BSON objectid value the value represents. It panics if
// the value is a BSON type other than objectid.
func (rv RawValue) ObjectID() ObjectID {
	core := convertToCoreValue(rv)
	return core.ObjectID()
}
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (rv RawValue) ObjectIDOK() (ObjectID, bool) {
	core := convertToCoreValue(rv)
	return core.ObjectIDOK()
}
// Boolean returns the boolean value the value represents. It panics if the
// value is a BSON type other than boolean.
func (rv RawValue) Boolean() bool {
	core := convertToCoreValue(rv)
	return core.Boolean()
}
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (rv RawValue) BooleanOK() (bool, bool) {
	core := convertToCoreValue(rv)
	return core.BooleanOK()
}
// DateTime returns the BSON datetime value the value represents as the raw
// int64 timestamp (the BSON specification defines datetime as milliseconds
// since the Unix epoch). It panics if the value is a BSON type other than
// datetime.
func (rv RawValue) DateTime() int64 {
	core := convertToCoreValue(rv)
	return core.DateTime()
}
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (rv RawValue) DateTimeOK() (int64, bool) {
	core := convertToCoreValue(rv)
	return core.DateTimeOK()
}
// Time returns the BSON datetime value the value represents as a time.Time. It
// panics if the value is a BSON type other than datetime.
func (rv RawValue) Time() time.Time {
	core := convertToCoreValue(rv)
	return core.Time()
}
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (rv RawValue) TimeOK() (time.Time, bool) {
	core := convertToCoreValue(rv)
	return core.TimeOK()
}
// Regex returns the pattern and options of the BSON regex value the value
// represents. It panics if the value is a BSON type other than regex.
func (rv RawValue) Regex() (pattern, options string) {
	core := convertToCoreValue(rv)
	return core.Regex()
}
// RegexOK is the same as Regex, except it returns a boolean instead of
// panicking.
func (rv RawValue) RegexOK() (pattern, options string, ok bool) {
	core := convertToCoreValue(rv)
	return core.RegexOK()
}
// DBPointer returns the BSON dbpointer value the value represents. It panics
// if the value is a BSON type other than DBPointer.
func (rv RawValue) DBPointer() (string, ObjectID) {
	core := convertToCoreValue(rv)
	return core.DBPointer()
}
// DBPointerOK is the same as DBPointer, except that it returns a boolean
// instead of panicking.
func (rv RawValue) DBPointerOK() (string, ObjectID, bool) {
	core := convertToCoreValue(rv)
	return core.DBPointerOK()
}
// JavaScript returns the BSON JavaScript code value the value represents. It
// panics if the value is a BSON type other than JavaScript code.
func (rv RawValue) JavaScript() string {
	core := convertToCoreValue(rv)
	return core.JavaScript()
}
// JavaScriptOK is the same as JavaScript, except that it returns a boolean
// instead of panicking.
func (rv RawValue) JavaScriptOK() (string, bool) {
	core := convertToCoreValue(rv)
	return core.JavaScriptOK()
}
// Symbol returns the BSON symbol value the value represents. It panics if the
// value is a BSON type other than symbol.
func (rv RawValue) Symbol() string {
	core := convertToCoreValue(rv)
	return core.Symbol()
}
// SymbolOK is the same as Symbol, except that it returns a boolean
// instead of panicking.
func (rv RawValue) SymbolOK() (string, bool) {
	core := convertToCoreValue(rv)
	return core.SymbolOK()
}
// CodeWithScope returns the code and scope document of the BSON JavaScript
// code with scope the value represents. It panics if the value is a BSON type
// other than JavaScript code with scope.
func (rv RawValue) CodeWithScope() (string, Raw) {
	core := convertToCoreValue(rv)
	js, scopeDoc := core.CodeWithScope()
	return js, Raw(scopeDoc)
}
// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
// panicking.
func (rv RawValue) CodeWithScopeOK() (string, Raw, bool) {
	core := convertToCoreValue(rv)
	js, scopeDoc, ok := core.CodeWithScopeOK()
	return js, Raw(scopeDoc), ok
}
// Int32 returns the int32 the value represents. It panics if the value is a
// BSON type other than int32.
func (rv RawValue) Int32() int32 {
	core := convertToCoreValue(rv)
	return core.Int32()
}
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int32OK() (int32, bool) {
	core := convertToCoreValue(rv)
	return core.Int32OK()
}
// Timestamp returns the two uint32 components (t, i) of the BSON timestamp
// value the value represents. It panics if the value is a BSON type other than
// timestamp.
func (rv RawValue) Timestamp() (t, i uint32) {
	core := convertToCoreValue(rv)
	return core.Timestamp()
}
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (rv RawValue) TimestampOK() (t, i uint32, ok bool) {
	core := convertToCoreValue(rv)
	return core.TimestampOK()
}
// Int64 returns the int64 the value represents. It panics if the value is a
// BSON type other than int64.
func (rv RawValue) Int64() int64 {
	core := convertToCoreValue(rv)
	return core.Int64()
}
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) Int64OK() (int64, bool) {
	core := convertToCoreValue(rv)
	return core.Int64OK()
}
// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric
// one, this method will panic.
func (rv RawValue) AsInt64() int64 {
	core := convertToCoreValue(rv)
	return core.AsInt64()
}
// AsInt64OK is the same as AsInt64, except that it returns a boolean instead of
// panicking.
func (rv RawValue) AsInt64OK() (int64, bool) {
	core := convertToCoreValue(rv)
	return core.AsInt64OK()
}
// Decimal128 returns the decimal the value represents, built with
// NewDecimal128 from the two values bsoncore.Value.Decimal128 returns. It
// panics if the value is a BSON type other than decimal.
func (rv RawValue) Decimal128() Decimal128 {
	high, low := convertToCoreValue(rv).Decimal128()
	return NewDecimal128(high, low)
}
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (rv RawValue) Decimal128OK() (Decimal128, bool) {
	core := convertToCoreValue(rv)
	high, low, ok := core.Decimal128OK()
	return NewDecimal128(high, low), ok
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
)
// defaultRegistry is the default Registry. It is built with NewRegistry, so it
// contains the default codecs and the primitive codecs. RawValue.Unmarshal
// falls back to it when the RawValue has no registry attached.
var defaultRegistry = NewRegistry()
// errNoEncoder is returned when there wasn't an encoder available for a type.
type errNoEncoder struct {
	// Type is the type for which no encoder was found; it may be nil.
	Type reflect.Type
}

// Error implements the error interface. A nil Type is rendered as "<nil>".
func (ene errNoEncoder) Error() string {
	switch ene.Type {
	case nil:
		return "no encoder found for <nil>"
	default:
		return "no encoder found for " + ene.Type.String()
	}
}
// errNoDecoder is returned when there wasn't a decoder available for a type.
type errNoDecoder struct {
	// Type is the type for which no decoder was found; it may be nil.
	Type reflect.Type
}

// Error implements the error interface. A nil Type is rendered as "<nil>",
// mirroring errNoEncoder, instead of panicking on the nil interface.
func (end errNoDecoder) Error() string {
	if end.Type == nil {
		return "no decoder found for <nil>"
	}
	return "no decoder found for " + end.Type.String()
}
// errNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type.
type errNoTypeMapEntry struct {
	// Type is the BSON type for which no type map entry was found.
	Type Type
}

// Error implements the error interface.
func (entme errNoTypeMapEntry) Error() string {
	typeName := entme.Type.String()
	return "no type map entry found for " + typeName
}
// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the documentation of
// the Register* methods for details on registering custom encoders and decoders. A Registry can
// have four main types of codecs:
//
// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and
// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value
// whose type matches the registered type exactly.
// If the registered type is an interface, the codec will be invoked when encoding or decoding
// values whose type is the interface, but not for values with concrete types that implement the
// interface.
//
// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and
// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs
// will be invoked when encoding or decoding values whose types implement the interface. An example
// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method
// for any value whose type implements bson.Marshaler, regardless of the value's concrete type.
//
// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type
// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}.
// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances,
// respectively, when decoding into a bson.D. The following code would change the behavior so these
// values decode as Go int instances instead:
//
//	intType := reflect.TypeOf(int(0))
//	registry.RegisterTypeMapEntry(bson.TypeInt32, intType)
//	registry.RegisterTypeMapEntry(bson.TypeInt64, intType)
//
// 4. Kind encoder/decoders - These can be registered using the RegisterKindEncoder and
// RegisterKindDecoder methods. The registered codec will be invoked when encoding or decoding
// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't
// match a registered type or interface encoder/decoder first. These methods should be used to change the
// behavior for all values for a specific kind.
//
// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure.
type Registry struct {
	// interfaceEncoders and interfaceDecoders hold the codecs registered per
	// interface type; a repeated registration for the same interface replaces
	// the codec in place, otherwise entries are appended.
	interfaceEncoders []interfaceValueEncoder
	interfaceDecoders []interfaceValueDecoder
	// typeEncoders and typeDecoders hold codecs keyed by exact reflect.Type.
	// LookupEncoder also stores the result of interface and kind lookups in
	// typeEncoders.
	typeEncoders *typeEncoderCache
	typeDecoders *typeDecoderCache
	// kindEncoders and kindDecoders hold codecs keyed by reflect.Kind.
	kindEncoders *kindEncoderCache
	kindDecoders *kindDecoderCache
	typeMap sync.Map // map[Type]reflect.Type
}
// NewRegistry creates a new Registry and populates it by calling
// registerDefaultEncoders, registerDefaultDecoders and registerPrimitiveCodecs,
// in that order.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.typeEncoders = new(typeEncoderCache)
	reg.typeDecoders = new(typeDecoderCache)
	reg.kindEncoders = new(kindEncoderCache)
	reg.kindDecoders = new(kindDecoderCache)

	registerDefaultEncoders(reg)
	registerDefaultDecoders(reg)
	registerPrimitiveCodecs(reg)
	return reg
}
// RegisterTypeEncoder registers the provided ValueEncoder for the provided type.
//
// The type will be used as provided, so an encoder can be registered for a type and a different
// encoder can be registered for a pointer to that type.
//
// If the given type is an interface, the encoder will be called when marshaling a type that is
// that interface. It will not be called when marshaling a non-interface type that implements the
// interface. To get the latter behavior, call RegisterInterfaceEncoder instead.
//
// RegisterTypeEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) {
	cache := r.typeEncoders
	cache.Store(valueType, enc)
}
// RegisterTypeDecoder registers the provided ValueDecoder for the provided type.
//
// The type will be used as provided, so a decoder can be registered for a type and a different
// decoder can be registered for a pointer to that type.
//
// If the given type is an interface, the decoder will be called when unmarshaling into a type that
// is that interface. It will not be called when unmarshaling into a non-interface type that
// implements the interface. To get the latter behavior, call RegisterInterfaceDecoder instead.
//
// RegisterTypeDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) {
	cache := r.typeDecoders
	cache.Store(valueType, dec)
}
// RegisterKindEncoder registers the provided ValueEncoder for the provided kind.
//
// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For
// example, consider the type MyInt defined as
//
//	type MyInt int32
//
// To define an encoder for MyInt and int32, use RegisterKindEncoder like
//
//	reg.RegisterKindEncoder(reflect.Int32, myEncoder)
//
// RegisterKindEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) {
	cache := r.kindEncoders
	cache.Store(kind, enc)
}
// RegisterKindDecoder registers the provided ValueDecoder for the provided kind.
//
// Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For
// example, consider the type MyInt defined as
//
//	type MyInt int32
//
// To define a decoder for MyInt and int32, use RegisterKindDecoder like
//
//	reg.RegisterKindDecoder(reflect.Int32, myDecoder)
//
// RegisterKindDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) {
	cache := r.kindDecoders
	cache.Store(kind, dec)
}
// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will
// be called when marshaling a type if the type implements iface or a pointer to the type
// implements iface. If an encoder is already registered for iface, it is replaced in place;
// otherwise the new entry is appended. If the provided type is not an interface
// (i.e. iface.Kind() != reflect.Interface), this method will panic.
//
// RegisterInterfaceEncoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) {
	if kind := iface.Kind(); kind != reflect.Interface {
		panic(fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+
			"got type %s with kind %s", iface, kind))
	}

	for i := range r.interfaceEncoders {
		if r.interfaceEncoders[i].i == iface {
			r.interfaceEncoders[i].ve = enc
			return
		}
	}

	entry := interfaceValueEncoder{i: iface, ve: enc}
	r.interfaceEncoders = append(r.interfaceEncoders, entry)
}
// RegisterInterfaceDecoder registers a decoder for the provided interface type iface. This decoder will
// be called when unmarshaling into a type if the type implements iface or a pointer to the type
// implements iface. If a decoder is already registered for iface, it is replaced in place;
// otherwise the new entry is appended. If the provided type is not an interface
// (i.e. iface.Kind() != reflect.Interface), this method will panic.
//
// RegisterInterfaceDecoder should not be called concurrently with any other Registry method.
func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) {
	if kind := iface.Kind(); kind != reflect.Interface {
		panic(fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+
			"got type %s with kind %s", iface, kind))
	}

	for i := range r.interfaceDecoders {
		if r.interfaceDecoders[i].i == iface {
			r.interfaceDecoders[i].vd = dec
			return
		}
	}

	entry := interfaceValueDecoder{i: iface, vd: dec}
	r.interfaceDecoders = append(r.interfaceDecoders, entry)
}
// RegisterTypeMapEntry registers the provided Go type for the given BSON type, replacing any
// previous entry. The primary usage for this mapping is decoding situations where an empty
// interface is used and a default type needs to be created and decoded into.
//
// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON
// documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents
// to decode to bson.Raw, use the following code:
//
//	reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{}))
func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) {
	entries := &r.typeMap
	entries.Store(bt, rt)
}
// LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup
// order:
//
// 1. An encoder registered for the exact type. If the given type is an interface, an encoder
// registered using RegisterTypeEncoder for that interface will be selected.
//
// 2. An encoder registered using RegisterInterfaceEncoder for an interface implemented by the type
// or by a pointer to the type. If the value matches multiple interfaces (e.g. the type implements
// bson.Marshaler and bson.ValueMarshaler), the first one registered will be selected.
// Note that registries constructed using bson.NewRegistry have driver-defined interfaces registered
// for the bson.Marshaler, bson.ValueMarshaler, and bson.Proxy interfaces, so those will take
// precedence over any new interfaces.
//
// 3. An encoder registered using RegisterKindEncoder for the kind of value.
//
// Encoders found in steps 2 and 3 are stored in the type encoder cache, so later lookups for the
// same type are answered by step 1.
//
// If no encoder is found (including for a nil type, or a type whose cached encoder is nil), an
// error of type errNoEncoder is returned. LookupEncoder is safe for concurrent use by multiple
// goroutines after all codecs and encoders are registered.
func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
	notFound := errNoEncoder{Type: valueType}
	if valueType == nil {
		return nil, notFound
	}

	// Step 1: exact type match.
	if cached, ok := r.lookupTypeEncoder(valueType); ok {
		if cached == nil {
			return nil, notFound
		}
		return cached, nil
	}

	// Step 2: interface match.
	if ienc, ok := r.lookupInterfaceEncoder(valueType, true); ok {
		return r.typeEncoders.LoadOrStore(valueType, ienc), nil
	}

	// Step 3: kind match.
	if kenc, ok := r.kindEncoders.Load(valueType.Kind()); ok {
		return r.storeTypeEncoder(valueType, kenc), nil
	}
	return nil, notFound
}
// storeTypeEncoder records enc as the encoder for the exact type rt using typeEncoders.LoadOrStore
// and returns whatever that call reports (presumably an encoder that was already stored for rt,
// if there is one, and enc otherwise).
func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder {
return r.typeEncoders.LoadOrStore(rt, enc)
}
// lookupTypeEncoder returns the encoder stored for the exact type rt and whether an entry was
// found. Interface and kind encoders are not consulted.
func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) {
return r.typeEncoders.Load(rt)
}
// lookupInterfaceEncoder walks the registered interface encoders in registration order and
// returns the encoder for the first interface that valueType implements.
//
// When allowAddr is true and valueType is not itself a pointer, an interface implemented only by
// *valueType also matches. In that case the result is a conditional-address encoder built from
// the matching encoder and a fallback: the encoder for an interface implemented by valueType
// itself if there is one, otherwise the kind encoder for valueType (which may be nil).
//
// The boolean result reports whether any interface matched. A nil valueType never matches.
func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) {
	if valueType == nil {
		return nil, false
	}

	// Pointer receivers are only worth checking for non-pointer types, and only if permitted.
	tryAddr := allowAddr && valueType.Kind() != reflect.Ptr

	for _, entry := range r.interfaceEncoders {
		if valueType.Implements(entry.i) {
			return entry.ve, true
		}
		if !tryAddr || !reflect.PtrTo(valueType).Implements(entry.i) {
			continue
		}

		// *valueType implements this interface. valueType itself may still implement an
		// interface registered further along, so look for that to use as the fallback.
		fallback, ok := r.lookupInterfaceEncoder(valueType, false)
		if !ok {
			fallback, _ = r.kindEncoders.Load(valueType.Kind())
		}
		return newCondAddrEncoder(entry.ve, fallback), true
	}

	return nil, false
}
// LookupDecoder returns the first matching decoder in the Registry. It uses the following lookup
// order:
//
// 1. A decoder registered for the exact type. If the given type is an interface, a decoder
// registered using RegisterTypeDecoder for that interface will be selected.
//
// 2. A decoder registered using RegisterInterfaceDecoder for an interface implemented by the type or by
// a pointer to the type. If the value matches multiple interfaces (e.g. the type implements
// bson.Unmarshaler and bson.ValueUnmarshaler), the first one registered will be selected.
// Note that registries constructed using bson.NewRegistry have driver-defined interfaces registered
// for the bson.Unmarshaler and bson.ValueUnmarshaler interfaces, so those will take
// precedence over any new interfaces.
//
// 3. A decoder registered using RegisterKindDecoder for the kind of value.
//
// Decoders found through steps 2 and 3 are stored under the exact type, so that later lookups for
// the same type are answered by step 1.
//
// If no decoder is found, an error of type ErrNoDecoder is returned. A nil valueType is reported
// with a plain error instead. LookupDecoder is safe for concurrent use by multiple goroutines
// after all codecs and decoders are registered.
func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) {
	if valueType == nil {
		return nil, errors.New("cannot perform a decoder lookup on <nil>")
	}

	// Step 1: exact type. An entry holding a nil decoder counts as "no decoder".
	if cached, ok := r.lookupTypeDecoder(valueType); ok {
		if cached != nil {
			return cached, nil
		}
		return nil, errNoDecoder{Type: valueType}
	}

	// Step 2: interfaces implemented by the type or by a pointer to it.
	if ifaceDec, ok := r.lookupInterfaceDecoder(valueType, true); ok {
		return r.storeTypeDecoder(valueType, ifaceDec), nil
	}

	// Step 3: the kind of the type.
	kindDec, ok := r.kindDecoders.Load(valueType.Kind())
	if !ok {
		return nil, errNoDecoder{Type: valueType}
	}
	return r.storeTypeDecoder(valueType, kindDec), nil
}
// lookupTypeDecoder returns the decoder stored for the exact type valueType and whether an entry
// was found. Interface and kind decoders are not consulted.
func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) {
return r.typeDecoders.Load(valueType)
}
// storeTypeDecoder records dec as the decoder for the exact type typ using typeDecoders.LoadOrStore
// and returns whatever that call reports (presumably a decoder that was already stored for typ,
// if there is one, and dec otherwise).
func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder {
return r.typeDecoders.LoadOrStore(typ, dec)
}
// lookupInterfaceDecoder walks the registered interface decoders in registration order and
// returns the decoder for the first interface that valueType implements.
//
// When allowAddr is true and valueType is not itself a pointer, an interface implemented only by
// *valueType also matches. In that case the result is a conditional-address decoder built from
// the matching decoder and a fallback: the decoder for an interface implemented by valueType
// itself if there is one, otherwise the kind decoder for valueType (which may be nil).
//
// The boolean result reports whether any interface matched. A nil valueType never matches.
func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) {
	// Calling Implements on a nil reflect.Type panics. Report "not found" instead, which is what
	// lookupInterfaceEncoder does for the same input.
	if valueType == nil {
		return nil, false
	}

	for _, idec := range r.interfaceDecoders {
		if valueType.Implements(idec.i) {
			return idec.vd, true
		}
		if allowAddr && valueType.Kind() != reflect.Ptr && reflect.PtrTo(valueType).Implements(idec.i) {
			// if *t implements an interface, this will catch if t implements an interface further
			// ahead in interfaceDecoders
			defaultDec, found := r.lookupInterfaceDecoder(valueType, false)
			if !found {
				defaultDec, _ = r.kindDecoders.Load(valueType.Kind())
			}
			return newCondAddrDecoder(idec.vd, defaultDec), true
		}
	}

	return nil, false
}
// LookupTypeMapEntry returns the Go type registered in the registry's type map for the BSON type
// bt. If there is no entry, or the stored entry is nil, ErrNoTypeMapEntry is returned.
//
// LookupTypeMapEntry should not be called concurrently with any other Registry method.
func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) {
	if entry, found := r.typeMap.Load(bt); found && entry != nil {
		return entry.(reflect.Type), nil
	}
	return nil, errNoTypeMapEntry{Type: bt}
}
// interfaceValueEncoder pairs an interface type with the encoder registered for it.
type interfaceValueEncoder struct {
// i is the interface type the encoder was registered for.
i reflect.Type
// ve is the encoder selected for types that implement i.
ve ValueEncoder
}
// interfaceValueDecoder pairs an interface type with the decoder registered for it.
type interfaceValueDecoder struct {
// i is the interface type the decoder was registered for.
i reflect.Type
// vd is the decoder selected for types that implement i.
vd ValueDecoder
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
)
// sliceCodec is the Codec used for slice values.
type sliceCodec struct {
// encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of
// BSON null. EncodeValue checks the EncodeContext's nilSliceAsEmpty setting as well, so a nil
// slice is only written as null when both are false.
encodeNilAsEmpty bool
}
// EncodeValue is the ValueEncoder for slice types.
//
// A nil slice is written as BSON null unless nil-as-empty is enabled on the codec or on ec. A
// slice of bytes is written as BSON binary, a D (or any slice type convertible to D) is written
// as a document, and every other slice is written as an array with one element per entry.
// Elements for which lookupElementEncoder reports errInvalidValue are written as null.
func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Slice {
		return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
	}

	nilAsEmpty := sc.encodeNilAsEmpty || ec.nilSliceAsEmpty
	if val.IsNil() && !nilAsEmpty {
		return vw.WriteNull()
	}

	sliceType := val.Type()
	elemType := sliceType.Elem()

	// A []byte is binary data, not an array of numbers. Copy it so the writer gets a plain
	// []byte regardless of the slice's named type.
	if elemType == tByte {
		raw := make([]byte, val.Len())
		reflect.Copy(reflect.ValueOf(raw), val)
		return vw.WriteBinary(raw)
	}

	// A []E is an ordered document, not an array.
	if sliceType == tD || sliceType.ConvertibleTo(tD) {
		doc := val.Convert(tD).Interface().(D)

		dw, err := vw.WriteDocument()
		if err != nil {
			return err
		}
		for _, elem := range doc {
			if err := encodeElement(ec, dw, elem); err != nil {
				return err
			}
		}
		return dw.WriteDocumentEnd()
	}

	aw, err := vw.WriteArray()
	if err != nil {
		return err
	}

	// For interface element types a failed lookup is tolerated here; the encoder is then
	// resolved per element below.
	elemEncoder, err := ec.LookupEncoder(elemType)
	if err != nil && elemType.Kind() != reflect.Interface {
		return err
	}

	for i := 0; i < val.Len(); i++ {
		itemEncoder, itemVal, lookupErr := lookupElementEncoder(ec, elemEncoder, val.Index(i))
		invalid := errors.Is(lookupErr, errInvalidValue)
		if lookupErr != nil && !invalid {
			return lookupErr
		}

		itemWriter, err := aw.WriteArrayElement()
		if err != nil {
			return err
		}

		if invalid {
			if err := itemWriter.WriteNull(); err != nil {
				return err
			}
			continue
		}

		if err := itemEncoder.EncodeValue(ec, itemWriter, itemVal); err != nil {
			return err
		}
	}

	return aw.WriteArrayEnd()
}
// DecodeValue is the ValueDecoder for slice types.
//
// The BSON type reported by vr decides how val is filled:
//
//   - Array: each element is decoded and appended.
//   - Null and Undefined: val is set to its zero value (a nil slice).
//   - Embedded document, or a reader reporting Type(0): only allowed when the slice element type
//     is E; the elements are decoded as key/value pairs.
//   - Binary (subtype 0x00 or 0x02) and String: only allowed for slices of bytes; the bytes are
//     copied into val.
//
// Any other BSON type is an error. When val already holds a non-nil slice, it is truncated to
// length zero and appended to, so its backing array may be reused.
func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Kind() != reflect.Slice {
		return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
	}

	switch vrType := vr.Type(); vrType {
	case TypeArray:
		// Decoded element by element below.
	case TypeNull:
		val.Set(reflect.Zero(val.Type()))
		return vr.ReadNull()
	case TypeUndefined:
		val.Set(reflect.Zero(val.Type()))
		return vr.ReadUndefined()
	case Type(0), TypeEmbeddedDocument:
		if val.Type().Elem() != tE {
			return fmt.Errorf("cannot decode document into %s", val.Type())
		}
	case TypeBinary:
		// Report the offending element type, as the string branch below does. The BSON type is
		// always "binary" here and says nothing about why decoding failed.
		if sliceType := val.Type().Elem(); sliceType != tByte {
			return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", sliceType)
		}
		data, subtype, err := vr.ReadBinary()
		if err != nil {
			return err
		}
		if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
			return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", TypeBinary, subtype)
		}

		if val.IsNil() {
			val.Set(reflect.MakeSlice(val.Type(), 0, len(data)))
		}
		val.SetLen(0)
		val.Set(reflect.AppendSlice(val, reflect.ValueOf(data)))
		return nil
	case TypeString:
		if sliceType := val.Type().Elem(); sliceType != tByte {
			return fmt.Errorf("SliceDecodeValue can only decode a string into a byte array, got %v", sliceType)
		}
		str, err := vr.ReadString()
		if err != nil {
			return err
		}
		byteStr := []byte(str)

		if val.IsNil() {
			val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr)))
		}
		val.SetLen(0)
		val.Set(reflect.AppendSlice(val, reflect.ValueOf(byteStr)))
		return nil
	default:
		return fmt.Errorf("cannot decode %v into a slice", vrType)
	}

	// Slices of E are read as documents; everything else is read as an array.
	var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error)
	switch val.Type().Elem() {
	case tE:
		elemsFunc = decodeD
	default:
		elemsFunc = decodeDefault
	}

	elems, err := elemsFunc(dc, vr, val)
	if err != nil {
		return err
	}

	if val.IsNil() {
		val.Set(reflect.MakeSlice(val.Type(), 0, len(elems)))
	}

	val.SetLen(0)
	val.Set(reflect.Append(val, elems...))

	return nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
)
// stringCodec is the Codec used for string values. It has no configuration of its own; decoding
// options are read from the DecodeContext passed to each call.
type stringCodec struct{}
// Compile-time assertion that stringCodec satisfies the typeDecoder interface, which allows it
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &stringCodec{}
// EncodeValue is the ValueEncoder for string types. Any value whose kind is reflect.String is
// written as a BSON string; every other kind is rejected with a ValueEncoderError.
func (sc *stringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if val.Kind() == reflect.String {
		return vw.WriteString(val.String())
	}

	return ValueEncoderError{
		Name:     "StringEncodeValue",
		Kinds:    []reflect.Kind{reflect.String},
		Received: val,
	}
}
// decodeType reads the next value from vr and returns it as a reflect.Value holding a Go string.
// t must have kind reflect.String; it is only used for validation, and the returned value always
// has the plain string type.
//
// Accepted BSON types are string, symbol, binary (subtype 0x00 or 0x02, bytes used verbatim),
// null and undefined (both yield the empty string) and, only when dc.objectIDAsHexString is set,
// ObjectID (yielding its hexadecimal form). Every other BSON type is an error.
func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t.Kind() != reflect.String {
		return emptyValue, ValueDecoderError{
			Name:     "StringDecodeValue",
			Kinds:    []reflect.Kind{reflect.String},
			Received: reflect.Zero(t),
		}
	}

	var decoded string

	switch vr.Type() {
	case TypeString:
		s, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		decoded = s
	case TypeObjectID:
		if !dc.objectIDAsHexString {
			return emptyValue, errors.New(
				"decoding an object ID into a string is not supported by default " +
					"(set Decoder.ObjectIDAsHexString to enable decoding as a hexadecimal string)")
		}
		oid, err := vr.ReadObjectID()
		if err != nil {
			return emptyValue, err
		}
		decoded = oid.Hex()
	case TypeSymbol:
		sym, err := vr.ReadSymbol()
		if err != nil {
			return emptyValue, err
		}
		decoded = sym
	case TypeBinary:
		data, subtype, err := vr.ReadBinary()
		if err != nil {
			return emptyValue, err
		}
		if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"}
		}
		decoded = string(data)
	case TypeNull:
		// Consume the value; the result stays the empty string.
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		// Consume the value; the result stays the empty string.
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type())
	}

	return reflect.ValueOf(decoded), nil
}
// DecodeValue is the ValueDecoder for string types. val must be settable and of kind
// reflect.String. The value is read with decodeType and stored with SetString, so named string
// types are supported.
func (sc *stringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Kind() != reflect.String {
		return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
	}

	decoded, err := sc.decodeType(dctx, vr, val.Type())
	if err == nil {
		val.SetString(decoded.String())
	}
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
)
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
type DecodeError struct {
	// keys holds the BSON key path in bottom-up order (innermost key first).
	keys []string
	// wrapped is the underlying decoding error.
	wrapped error
}

// Unwrap returns the underlying error.
func (de *DecodeError) Unwrap() error {
	return de.wrapped
}

// Error implements the error interface. The message contains the key path in top-down order,
// joined with ".", followed by the underlying error.
func (de *DecodeError) Error() string {
	// de.keys is built up while the error propagates up the stack of BSON keys, so it is stored
	// innermost-first. Keys returns it in reading order.
	path := strings.Join(de.Keys(), ".")
	return fmt.Sprintf("error decoding key %s: %v", path, de.wrapped)
}

// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"].
//
// The returned slice is a fresh copy. Modifying it does not affect the error.
func (de *DecodeError) Keys() []string {
	n := len(de.keys)
	topDown := make([]string, n)
	for i, key := range de.keys {
		topDown[n-1-i] = key
	}
	return topDown
}
// mapElementsEncoder handles encoding of the values of an inline map.
//
// structCodec.EncodeValue calls encodeMapElements with the document writer of the enclosing
// struct, the map value, and a function that reports whether a key is already used by one of the
// struct's own fields.
type mapElementsEncoder interface {
encodeMapElements(EncodeContext, DocumentWriter, reflect.Value, func(string) bool) error
}
// structCodec is the Codec used for struct values.
type structCodec struct {
// cache memoizes the descriptions built by describeStruct, keyed by struct type.
cache sync.Map // map[reflect.Type]*structDescription
// inlineMapEncoder writes the entries of an inline map field into the enclosing document.
inlineMapEncoder mapElementsEncoder
// decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the
// destination value passed to Decode before unmarshaling BSON documents into them.
decodeZeroStruct bool
// decodeDeepZeroInline causes DecodeValue to replace the destination struct with the result of
// deepZero before unmarshaling BSON documents into it. It only takes effect for structs that
// have at least one field tagged "inline".
decodeDeepZeroInline bool
// encodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g.
// MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag
// option is set.
encodeOmitDefaultStruct bool
// allowUnexportedFields allows encoding and decoding values from un-exported struct fields.
// As used by describeStructSlow, it only admits un-exported embedded (anonymous) fields; other
// un-exported fields are always skipped.
allowUnexportedFields bool
// overwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is
// a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The
// default value is true.
overwriteDuplicatedInlinedFields bool
}
// Compile-time assertions that *structCodec implements both ValueEncoder and ValueDecoder.
var (
_ ValueEncoder = &structCodec{}
_ ValueDecoder = &structCodec{}
)
// newStructCodec returns a structCodec that uses elemEncoder to write the entries of inline map
// fields. overwriteDuplicatedInlinedFields is enabled. All other options start out false.
func newStructCodec(elemEncoder mapElementsEncoder) *structCodec {
	sc := new(structCodec)
	sc.inlineMapEncoder = elemEncoder
	sc.overwriteDuplicatedInlinedFields = true
	return sc
}
// EncodeValue handles encoding generic struct types.
//
// val must be a valid value of kind reflect.Struct. The struct is written as a BSON document:
// one element per described field, in the order of the struct description, followed by the
// entries of the inline map field if the struct has one.
func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Struct {
return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates)
if err != nil {
return err
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
var rv reflect.Value
for _, desc := range sd.fl {
// Direct fields are addressed by index; fields promoted from inline structs by index path.
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
rv, err = fieldByIndexErr(val, desc.inline)
if err != nil {
// The inline path could not be followed (presumably a nil embedded pointer), so
// there is nothing to encode for this field.
continue
}
}
// desc is a per-iteration copy, so the changes below do not touch the cached description.
if ec.omitEmpty {
desc.omitEmpty = true
}
desc.encoder, rv, err = lookupElementEncoder(ec, desc.encoder, rv)
if err != nil && !errors.Is(err, errInvalidValue) {
return err
}
// errInvalidValue: the field is written as null, or left out when omitEmpty applies.
if errors.Is(err, errInvalidValue) {
if desc.omitEmpty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
err = vw2.WriteNull()
if err != nil {
return err
}
continue
}
if desc.encoder == nil {
return errNoEncoder{Type: rv.Type()}
}
encoder := desc.encoder
var empty bool
if rv.Kind() == reflect.Interface {
// isEmpty will not treat an interface rv as an interface, so we need to check for the
// nil interface separately.
empty = rv.IsNil()
} else {
empty = isEmpty(rv, sc.encodeOmitDefaultStruct || ec.omitZeroStruct)
}
if desc.omitEmpty && empty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
// The field is encoded with a context rebuilt from the options listed here; minSize is also
// enabled by the field's own "minsize" tag.
// NOTE(review): ec.omitEmpty is consulted above but is not copied into this context —
// confirm whether nested values are meant to see it.
ectx := EncodeContext{
Registry: ec.Registry,
minSize: desc.minSize || ec.minSize,
errorOnInlineDuplicates: ec.errorOnInlineDuplicates,
stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt,
nilMapAsEmpty: ec.nilMapAsEmpty,
nilSliceAsEmpty: ec.nilSliceAsEmpty,
nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty,
omitZeroStruct: ec.omitZeroStruct,
useJSONStructTags: ec.useJSONStructTags,
}
err = encoder.EncodeValue(ectx, vw2, rv)
if err != nil {
return err
}
}
// The inline map, if any, is written last. collisionFn lets the map encoder detect keys that
// clash with the struct's own field names.
if sd.inlineMap >= 0 {
rv := val.Field(sd.inlineMap)
collisionFn := func(key string) bool {
_, exists := sd.fm[key]
return exists
}
err = sc.inlineMapEncoder.encodeMapElements(ec, dw, rv, collisionFn)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// newDecodeError attaches the BSON key to a decoding error.
//
// If original already is (or wraps) a *DecodeError, key is appended to that error's key list and
// that same *DecodeError is returned, so the path grows as the error travels up through nested
// documents. Otherwise a new *DecodeError wrapping original is returned with key as its only key.
func newDecodeError(key string, original error) error {
	var existing *DecodeError
	if errors.As(original, &existing) {
		existing.keys = append(existing.keys, key)
		return existing
	}

	return &DecodeError{
		keys:    []string{key},
		wrapped: original,
	}
}
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
//
// val must be settable and of kind reflect.Struct. A BSON null or undefined value sets val to
// its zero value. An embedded document (or a reader reporting Type(0)) is decoded field by
// field. Any other BSON type is an error. Errors raised while decoding a field are wrapped in a
// DecodeError carrying the field's BSON key.
func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Struct {
return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
switch vrType := vr.Type(); vrType {
case Type(0), TypeEmbeddedDocument:
case TypeNull:
if err := vr.ReadNull(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
case TypeUndefined:
if err := vr.ReadUndefined(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
// Duplicate-key errors are an encode-side option, so errorOnDuplicates is passed as false.
sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false)
if err != nil {
return err
}
if sc.decodeZeroStruct || dc.zeroStructs {
val.Set(reflect.Zero(val.Type()))
}
if sc.decodeDeepZeroInline && sd.inline {
val.Set(deepZero(val.Type()))
}
// When the struct has an inline map, resolve the decoder for its element type once up front.
var decoder ValueDecoder
var inlineMap reflect.Value
if sd.inlineMap >= 0 {
inlineMap = val.Field(sd.inlineMap)
decoder, err = dc.LookupDecoder(inlineMap.Type().Elem())
if err != nil {
return err
}
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
// Read elements until the document reader reports ErrEOD.
for {
name, vr, err := dr.ReadElement()
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
fd, exists := sd.fm[name]
if !exists {
// if the original name isn't found in the struct description, try again with the name in lowercase
// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
// names
fd, exists = sd.fm[strings.ToLower(name)]
}
if !exists {
// Unknown key: skip it, unless the struct has an inline map to collect it.
if sd.inlineMap < 0 {
// The encoding/json package requires a flag to return on error for non-existent fields.
// This functionality seems appropriate for the struct codec.
err = vr.Skip()
if err != nil {
return err
}
continue
}
if inlineMap.IsNil() {
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
}
elem := reflect.New(inlineMap.Type().Elem()).Elem()
err = decoder.DecodeValue(dc, vr, elem)
if err != nil {
return err
}
inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
continue
}
var field reflect.Value
if fd.inline == nil {
field = val.Field(fd.idx)
} else {
// getInlineField allocates nil pointers along the inline path when needed.
field, err = getInlineField(val, fd.inline)
if err != nil {
return err
}
}
// An interface field that already holds a non-nil pointer is decoded into the value that
// pointer refers to, using a decoder looked up for that value's concrete type.
if field.Kind() == reflect.Interface && !field.IsNil() && field.Elem().Kind() == reflect.Ptr {
v := field.Elem().Elem()
decoder, err = dc.LookupDecoder(v.Type())
if err != nil {
return err
}
err = decoder.DecodeValue(dc, vr, v)
if err != nil {
return newDecodeError(fd.name, err)
}
continue
}
if !field.CanSet() { // Being settable is a super set of being addressable.
innerErr := fmt.Errorf("field %v is not settable", field)
return newDecodeError(fd.name, innerErr)
}
// Allocate nil pointer fields so the decoder has something to decode into.
if field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Addr()
// The field is decoded with a context rebuilt from the options listed here; truncate is also
// enabled by the field's own "truncate" tag.
dctx := DecodeContext{
Registry: dc.Registry,
truncate: fd.truncate || dc.truncate,
defaultDocumentType: dc.defaultDocumentType,
binaryAsSlice: dc.binaryAsSlice,
objectIDAsHexString: dc.objectIDAsHexString,
useJSONStructTags: dc.useJSONStructTags,
useLocalTimeZone: dc.useLocalTimeZone,
zeroMaps: dc.zeroMaps,
zeroStructs: dc.zeroStructs,
}
if fd.decoder == nil {
return newDecodeError(fd.name, errNoDecoder{Type: field.Elem().Type()})
}
err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
if err != nil {
return newDecodeError(fd.name, err)
}
}
return nil
}
// isEmpty reports whether v counts as empty for the purposes of the "omitempty" option.
//
//   - If v's type implements Zeroer (and v is not a nil pointer), the answer is v.IsZero().
//   - Arrays, maps, slices and strings are empty when their length is zero.
//   - Structs are never empty unless omitZeroStruct is set. With it set, a time.Time is empty
//     when it is the zero time, and any other struct is empty when every exported or embedded
//     field is empty by these same rules.
//   - Everything else is empty when it is invalid or the zero value of its type.
func isEmpty(v reflect.Value, omitZeroStruct bool) bool {
	kind := v.Kind()

	// Skip the Zeroer call for nil pointers, where invoking the method could dereference nil.
	nilPtr := kind == reflect.Ptr && v.IsNil()
	if !nilPtr && v.Type().Implements(tZeroer) {
		return v.Interface().(Zeroer).IsZero()
	}

	switch kind {
	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
		return v.Len() == 0
	case reflect.Struct:
		if !omitZeroStruct {
			return false
		}

		typ := v.Type()
		if typ == tTime {
			return v.Interface().(time.Time).IsZero()
		}

		for i, n := 0, typ.NumField(); i < n; i++ {
			sf := typ.Field(i)
			unexported := sf.PkgPath != ""
			if unexported && !sf.Anonymous {
				continue // Private field
			}
			if !isEmpty(v.Field(i), omitZeroStruct) {
				return false
			}
		}
		return true
	}

	return !v.IsValid() || v.IsZero()
}
// structDescription is the result of analyzing a struct type for encoding and decoding.
type structDescription struct {
// fm maps each BSON key name to the description of the field that owns it.
fm map[string]fieldDescription
// fl lists the same field descriptions, sorted with byIndex.
fl []fieldDescription
// inlineMap is the index of the struct's inline map field, or -1 if it has none.
inlineMap int
// inline is true if at least one field of the struct is tagged "inline".
inline bool
}
// fieldDescription describes how a single struct field is encoded and decoded.
type fieldDescription struct {
name string // BSON key name
fieldName string // struct field name
// idx is the index of the field within the struct that declares it.
idx int
// omitEmpty, minSize and truncate mirror the struct tag options of the same names.
omitEmpty bool
minSize bool
truncate bool
// inline is nil for direct fields. For a field promoted from an inline struct, it is the
// index path from the outer struct down to the field.
inline []int
// encoder and decoder are looked up from the registry for the field's type when the struct
// is described. Either may be nil if the lookup failed.
encoder ValueEncoder
decoder ValueDecoder
}
// byIndex implements sort.Interface and orders field descriptions by their position in the
// struct, following the index path for fields promoted from inline structs.
type byIndex []fieldDescription

// Len returns the number of field descriptions.
func (bi byIndex) Len() int {
	return len(bi)
}

// Swap exchanges the descriptions at positions i and j.
func (bi byIndex) Swap(i, j int) {
	bi[i], bi[j] = bi[j], bi[i]
}

// Less reports whether the field at position i comes before the field at position j.
//
// Fields are compared first by their index in the top-level struct, then element by element
// along their inline paths, and finally by path length (shorter paths first).
func (bi byIndex) Less(i, j int) bool {
	left, right := &bi[i], &bi[j]

	// If a field is inlined, its index in the top level struct is stored at inline[0]
	leftTop, rightTop := left.idx, right.idx
	if len(left.inline) > 0 {
		leftTop = left.inline[0]
	}
	if len(right.inline) > 0 {
		rightTop = right.inline[0]
	}
	if leftTop != rightTop {
		return leftTop < rightTop
	}

	for depth := range left.inline {
		if depth >= len(right.inline) {
			return false
		}
		if l, r := left.inline[depth], right.inline[depth]; l != r {
			return l < r
		}
	}

	return len(left.inline) < len(right.inline)
}
// describeStruct returns the description of struct type t, building it with describeStructSlow
// on first use and caching it afterwards.
//
// The cache is keyed by t alone. A cached description is returned as is, whatever values
// useJSONStructTags and errorOnDuplicates have on later calls.
func (sc *structCodec) describeStruct(
	r *Registry,
	t reflect.Type,
	useJSONStructTags bool,
	errorOnDuplicates bool,
) (*structDescription, error) {
	// We need to analyze the struct, including getting the tags, collecting
	// information about inlining, and create a map of the field name to the field.
	if cached, ok := sc.cache.Load(t); ok {
		return cached.(*structDescription), nil
	}

	// TODO(charlie): Only describe the struct once when called
	// concurrently with the same type.
	fresh, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates)
	if err != nil {
		return nil, err
	}

	// If another goroutine stored a description in the meantime, LoadOrStore hands that one
	// back. Otherwise it stores and returns ours.
	actual, _ := sc.cache.LoadOrStore(t, fresh)
	return actual.(*structDescription), nil
}
// describeStructSlow analyzes struct type t and builds its description without consulting the
// cache.
//
// It walks the fields of t, parses their struct tags, looks up an encoder and decoder for each
// field type, and expands fields tagged "inline": an inline map is recorded in inlineMap, and the
// fields of an inline struct (or pointer to struct) are added as if they belonged to t. Fields
// that end up with the same BSON key are then resolved with dominantField.
//
// An error is returned for invalid struct tags, invalid inline fields (more than one inline map,
// an inline map without string keys, or an inline field that is not a struct, struct pointer or
// map) and for duplicated keys that cannot or must not be resolved.
func (sc *structCodec) describeStructSlow(
r *Registry,
t reflect.Type,
useJSONStructTags bool,
errorOnDuplicates bool,
) (*structDescription, error) {
numFields := t.NumField()
sd := &structDescription{
fm: make(map[string]fieldDescription, numFields),
fl: make([]fieldDescription, 0, numFields),
inlineMap: -1,
}
var fields []fieldDescription
for i := 0; i < numFields; i++ {
sf := t.Field(i)
// Unexported fields are skipped, except unexported embedded fields when
// allowUnexportedFields is set.
if sf.PkgPath != "" && (!sc.allowUnexportedFields || !sf.Anonymous) {
// field is private or unexported fields aren't allowed, ignore
continue
}
sfType := sf.Type
// A failed lookup is not an error here: the codec is left nil, and EncodeValue and
// DecodeValue report the missing codec if the field is actually used.
encoder, err := r.LookupEncoder(sfType)
if err != nil {
encoder = nil
}
decoder, err := r.LookupDecoder(sfType)
if err != nil {
decoder = nil
}
description := fieldDescription{
fieldName: sf.Name,
idx: i,
encoder: encoder,
decoder: decoder,
}
var stags *structTags
// If the caller requested that we use JSON struct tags, use parseJSONStructTags, which
// falls back to the json tag, instead of parseStructTags.
if useJSONStructTags {
stags, err = parseJSONStructTags(sf)
} else {
stags, err = parseStructTags(sf)
}
if err != nil {
return nil, err
}
if stags.Skip {
continue
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
description.minSize = stags.MinSize
description.truncate = stags.Truncate
if stags.Inline {
sd.inline = true
switch sfType.Kind() {
case reflect.Map:
if sd.inlineMap >= 0 {
return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
}
if sfType.Key() != tString {
return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
}
sd.inlineMap = description.idx
case reflect.Ptr:
// A pointer is only accepted when it points to a struct, which is then handled
// like an inline struct.
sfType = sfType.Elem()
if sfType.Kind() != reflect.Struct {
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
fallthrough
case reflect.Struct:
inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates)
if err != nil {
return nil, err
}
// Prefix each promoted field's index path with the index of the inline field.
for _, fd := range inlinesf.fl {
if fd.inline == nil {
fd.inline = []int{i, fd.idx}
} else {
fd.inline = append([]int{i}, fd.inline...)
}
fields = append(fields, fd)
}
default:
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
// The inline field itself is not added as a field of its own.
continue
}
fields = append(fields, description)
}
// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
sort.Slice(fields, func(i, j int) bool {
x := fields
// sort field by name, breaking ties with depth, then
// breaking ties with index sequence.
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].inline) != len(x[j].inline) {
return len(x[i].inline) < len(x[j].inline)
}
return byIndex(x).Less(i, j)
})
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
sd.fl = append(sd.fl, fi)
sd.fm[name] = fi
continue
}
// Several fields share this name. That is an error if no field dominates, if overwriting
// duplicates is disabled on the codec, or if the caller asked for errors on duplicates.
dominant, ok := dominantField(fields[i : i+advance])
if !ok || !sc.overwriteDuplicatedInlinedFields || errorOnDuplicates {
return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
}
sd.fl = append(sd.fl, dominant)
sd.fm[name] = dominant
}
// Restore struct order for the field list.
sort.Sort(byIndex(sd.fl))
return sd, nil
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's inlining rules. If there are multiple top-level
// fields, the boolean will be false: This condition is an error in Go
// and we skip all the fields.
//
// fields must not be empty and must already be sorted by increasing inline depth.
func dominantField(fields []fieldDescription) (fieldDescription, bool) {
	// Because of the sort order the first field is the shallowest one. It dominates unless the
	// next field sits at the same depth, in which case the name is ambiguous.
	if len(fields) < 2 || len(fields[0].inline) != len(fields[1].inline) {
		return fields[0], true
	}
	return fieldDescription{}, false
}
// fieldByIndexErr returns the nested field of v selected by index, like
// reflect.Value.FieldByIndex, but converts a panic raised during the lookup (for example when
// the path goes through a nil pointer to a struct, or an index is out of range) into an error.
//
// On failure the returned value is the zero reflect.Value and err describes the panic.
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		switch r := recovered.(type) {
		case error:
			err = r
		case string:
			err = fmt.Errorf("%s", r)
		default:
			// Never swallow a panic silently: without this case an unexpected panic value
			// would leave err nil next to an invalid result.
			err = fmt.Errorf("%v", r)
		}
	}()

	return v.FieldByIndex(index), nil
}
// getInlineField returns the field of val selected by the index path, allocating missing
// parents along the way.
//
// If the path cannot be followed as is, the parent of the target field is located (recursively
// repairing its own parents if necessary), set to a newly allocated value of the type it points
// to, and the lookup is retried.
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
	if field, err := fieldByIndexErr(val, index); err == nil {
		return field, nil
	}

	// if parent of this element doesn't exist, fix its parent
	parentPath := index[:len(index)-1]
	parent, err := fieldByIndexErr(val, parentPath)
	if err != nil {
		if parent, err = getInlineField(val, parentPath); err != nil {
			return parent, err
		}
	}

	parent.Set(reflect.New(parent.Type().Elem()))
	return fieldByIndexErr(val, index)
}
// deepZero builds a value of struct type st that structCodec.DecodeValue uses to reset the
// destination when decodeDeepZeroInline is set.
//
// For a struct type with at least one field, the result is a new, addressable zero value of st.
// For any other type, including a struct type without fields, result is never assigned and the
// zero reflect.Value is returned.
//
// NOTE(review): the inner branch is entered for fields whose kind is reflect.Struct, but it then
// calls f.Type().Elem(), which the reflect package documents as panicking for types that are not
// an array, chan, map, pointer or slice. It looks like the check was meant to match
// pointer-to-struct fields — confirm the intended behavior before relying on this path.
func deepZero(st reflect.Type) (result reflect.Value) {
if st.Kind() == reflect.Struct {
numField := st.NumField()
for i := 0; i < numField; i++ {
// Allocate the result lazily, on the first field (emptyValue is presumably the zero
// reflect.Value).
if result == emptyValue {
result = reflect.Indirect(reflect.New(st))
}
f := result.Field(i)
if f.CanInterface() {
if f.Type().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem())))
}
}
}
}
return result
}
// recursivePointerTo returns a pointer to a newly allocated zero value of v's type (after
// dereferencing v if it is a pointer). If that type is a struct, every field of v that is a
// non-nil pointer to a struct is given a freshly allocated value in the result, built the same
// way. All other fields of the result keep their zero value.
func recursivePointerTo(v reflect.Value) reflect.Value {
	src := reflect.Indirect(v)
	fresh := reflect.New(src.Type())
	if src.Kind() != reflect.Struct {
		return fresh
	}

	dst := fresh.Elem()
	for i := 0; i < src.NumField(); i++ {
		field := src.Field(i)
		// Only non-nil pointers to structs are recreated. For a nil pointer, Elem is the
		// invalid value and so fails the struct check.
		if field.Kind() != reflect.Ptr || field.Elem().Kind() != reflect.Struct {
			continue
		}
		dst.Field(i).Set(recursivePointerTo(field))
	}

	return fresh
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"strings"
)
// structTags represents the struct tag fields that the StructCodec uses during
// the encoding and decoding process.
//
// In the case of a struct, the lowercased field name is used as the key for each exported
// field but this behavior may be changed using a struct tag. The tag may also contain flags to
// adjust the marshalling behavior for the field.
//
// The properties are defined below:
//
//	Name       The BSON key for the field: the first item of the tag if it is not empty,
//	           otherwise the lowercased field name.
//
//	OmitEmpty  Only include the field if it's not set to the zero value for the type or to
//	           empty slices or maps.
//
//	MinSize    Marshal an integer of a type larger than 32 bits value as an int32, if that's
//	           feasible while preserving the numeric value.
//
//	Truncate   When unmarshaling a BSON double, it is permitted to lose precision to fit within
//	           a float32.
//
//	Inline     Inline the field, which must be a struct or a map, causing all of its fields
//	           or keys to be processed as if they were part of the outer struct. For maps,
//	           keys must not conflict with the bson keys of other struct fields.
//
//	Skip       This struct field should be skipped. This is usually denoted by parsing a "-"
//	           for the name.
type structTags struct {
Name string
OmitEmpty bool
MinSize bool
Truncate bool
Inline bool
Skip bool
}
// parseStructTags is the struct tag parser used by the structCodec by default.
// It handles the bson struct tag. See the documentation for structTags to see
// what each of the returned fields means.
//
// If there is no name in the struct tag fields, the struct field name is lowercased.
// The tag formats accepted are:
//
//	"[<key>][,<flag1>[,<flag2>]]"
//
//	`(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
//
// An example:
//
//	type T struct {
//	    A bool
//	    B int    "myb"
//	    C string "myc,omitempty"
//	    D string `bson:",omitempty" json:"jsonkey"`
//	    E int64  ",minsize"
//	    F int64  "myf,omitempty,minsize"
//	}
//
// A struct tag either consisting entirely of '-' or with a bson key with a
// value consisting entirely of '-' will return a structTags with Skip true and
// the remaining fields will be their default values.
func parseStructTags(sf reflect.StructField) (*structTags, error) {
	tag, hasBSON := sf.Tag.Lookup("bson")
	if !hasBSON {
		// Without a bson key, a non-empty tag that contains no ':' at all (so no key:"value"
		// pair) is taken verbatim as the bson tag.
		if raw := string(sf.Tag); raw != "" && !strings.Contains(raw, ":") {
			tag = raw
		}
	}

	return parseTags(strings.ToLower(sf.Name), tag)
}
// parseJSONStructTags has the same behavior as parseStructTags but falls back
// to the json tag on a field where the bson tag isn't available. If neither key
// is present, a non-empty tag without any ':' is used as a bare "key,flags" tag.
func parseJSONStructTags(sf reflect.StructField) (*structTags, error) {
	tag, found := sf.Tag.Lookup("bson")
	if !found {
		tag, found = sf.Tag.Lookup("json")
	}
	if raw := string(sf.Tag); !found && len(raw) > 0 && !strings.Contains(raw, ":") {
		tag = raw
	}
	return parseTags(strings.ToLower(sf.Name), tag)
}
// parseTags turns an already-extracted tag string into a structTags.
//
// A tag of exactly "-" yields a structTags with only Skip set. Otherwise the
// tag is split on commas: a non-empty first segment replaces key as the name,
// and every segment (the first one included) is matched against the flags
// "omitempty", "minsize", "truncate" and "inline". Unrecognized segments are
// ignored. The returned error is always nil.
func parseTags(key string, tag string) (*structTags, error) {
	parsed := new(structTags)
	if tag == "-" {
		parsed.Skip = true
		return parsed, nil
	}

	// strings.Split always returns at least one segment for a non-empty
	// separator, so indexing the first one is safe even for an empty tag.
	segments := strings.Split(tag, ",")
	if name := segments[0]; name != "" {
		key = name
	}
	for _, segment := range segments {
		switch segment {
		case "omitempty":
			parsed.OmitEmpty = true
		case "minsize":
			parsed.MinSize = true
		case "truncate":
			parsed.Truncate = true
		case "inline":
			parsed.Inline = true
		}
	}

	parsed.Name = key
	return parsed, nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"time"
)
const (
	// timeFormatString is the layout handed to time.Parse when a time.Time is
	// decoded from a BSON string.
	timeFormatString = "2006-01-02T15:04:05.999Z07:00"
)
// timeCodec is the Codec used for time.Time values.
type timeCodec struct {
	// useLocalTimeZone specifies if we should decode into the local time zone. Defaults to false.
	// Decoded values are converted to UTC only when both this field and the
	// DecodeContext's useLocalTimeZone are false.
	useLocalTimeZone bool
}

// Assert that timeCodec satisfies the typeDecoder interface, which allows it to be used
// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection.
var _ typeDecoder = &timeCodec{}
// decodeType reads the next value from vr and returns it as a reflect.Value
// holding a time.Time. t must be tTime, otherwise a ValueDecoderError is
// returned.
//
// Accepted BSON types:
//   - datetime and int64: interpreted as milliseconds since the Unix epoch
//   - string: parsed with timeFormatString
//   - timestamp: the first value returned by ReadTimestamp is used as seconds
//     since the Unix epoch
//   - null and undefined: yield the zero time.Time
//
// Unless tc.useLocalTimeZone or dc.useLocalTimeZone is set, the result is
// converted to UTC.
func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tTime {
		return emptyValue, ValueDecoderError{
			Name:     "TimeDecodeValue",
			Types:    []reflect.Type{tTime},
			Received: reflect.Zero(t),
		}
	}

	var decoded time.Time
	switch bsonType := vr.Type(); bsonType {
	case TypeDateTime:
		millis, err := vr.ReadDateTime()
		if err != nil {
			return emptyValue, err
		}
		decoded = time.UnixMilli(millis)
	case TypeString:
		// Strings are assumed to follow timeFormatString.
		raw, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		parsed, err := time.Parse(timeFormatString, raw)
		if err != nil {
			return emptyValue, err
		}
		decoded = parsed
	case TypeInt64:
		millis, err := vr.ReadInt64()
		if err != nil {
			return emptyValue, err
		}
		decoded = time.UnixMilli(millis)
	case TypeTimestamp:
		seconds, _, err := vr.ReadTimestamp()
		if err != nil {
			return emptyValue, err
		}
		decoded = time.Unix(int64(seconds), 0)
	case TypeNull:
		if err := vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err := vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", bsonType)
	}

	if keepLocal := tc.useLocalTimeZone || dc.useLocalTimeZone; !keepLocal {
		decoded = decoded.UTC()
	}
	return reflect.ValueOf(decoded), nil
}
// DecodeValue is the ValueDecoderFunc for time.Time. val must be a settable
// time.Time; the decoded value is stored in it only when decoding succeeds.
func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !(val.CanSet() && val.Type() == tTime) {
		return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val}
	}

	decoded, err := tc.decodeType(dc, vr, tTime)
	if err == nil {
		val.Set(decoded)
	}
	return err
}
// EncodeValue is the ValueEncoderFunc for time.Time. The value is converted
// with NewDateTimeFromTime and written as a BSON datetime.
func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !(val.IsValid() && val.Type() == tTime) {
		return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
	}

	millis := int64(NewDateTimeFromTime(val.Interface().(time.Time)))
	return vw.WriteDateTime(millis)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"net/url"
"reflect"
"time"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// Type represents a BSON type. Its values are the single-byte element type
// codes declared as the Type* constants in this package.
type Type byte
// String returns the string representation of the BSON type's name. The
// formatting is delegated to bsoncore.Type.
func (bt Type) String() string {
	return bsoncore.Type(bt).String()
}
// IsValid reports whether bt is one of the BSON element types declared in this
// package.
func (bt Type) IsValid() bool {
	switch bt {
	case TypeDouble,
		TypeString,
		TypeEmbeddedDocument,
		TypeArray,
		TypeBinary,
		TypeUndefined,
		TypeObjectID,
		TypeBoolean,
		TypeDateTime,
		TypeNull,
		TypeRegex,
		TypeDBPointer,
		TypeJavaScript,
		TypeSymbol,
		TypeCodeWithScope,
		TypeInt32,
		TypeTimestamp,
		TypeInt64,
		TypeDecimal128,
		TypeMinKey,
		TypeMaxKey:
		return true
	}
	return false
}
// BSON element types as described in https://bsonspec.org/spec.html.
//
// The codes run contiguously from 0x01 (TypeDouble) to 0x13 (TypeDecimal128);
// TypeMaxKey (0x7F) and TypeMinKey (0xFF) sit outside that range.
const (
	TypeDouble           Type = 0x01
	TypeString           Type = 0x02
	TypeEmbeddedDocument Type = 0x03
	TypeArray            Type = 0x04
	TypeBinary           Type = 0x05
	TypeUndefined        Type = 0x06
	TypeObjectID         Type = 0x07
	TypeBoolean          Type = 0x08
	TypeDateTime         Type = 0x09
	TypeNull             Type = 0x0A
	TypeRegex            Type = 0x0B
	TypeDBPointer        Type = 0x0C
	TypeJavaScript       Type = 0x0D
	TypeSymbol           Type = 0x0E
	TypeCodeWithScope    Type = 0x0F
	TypeInt32            Type = 0x10
	TypeTimestamp        Type = 0x11
	TypeInt64            Type = 0x12
	TypeDecimal128       Type = 0x13
	TypeMaxKey           Type = 0x7F
	TypeMinKey           Type = 0xFF
)
// BSON binary element subtypes as described in https://bsonspec.org/spec.html.
//
// These are plain bytes rather than Type values: they identify the subtype byte
// of a binary element, not an element type.
const (
	TypeBinaryGeneric     byte = 0x00
	TypeBinaryFunction    byte = 0x01
	TypeBinaryBinaryOld   byte = 0x02
	TypeBinaryUUIDOld     byte = 0x03
	TypeBinaryUUID        byte = 0x04
	TypeBinaryMD5         byte = 0x05
	TypeBinaryEncrypted   byte = 0x06
	TypeBinaryColumn      byte = 0x07
	TypeBinarySensitive   byte = 0x08
	TypeBinaryVector      byte = 0x09
	TypeBinaryUserDefined byte = 0x80
)
// Cached reflect.Type values for the Go and BSON types this package compares
// against, so reflect.TypeOf is evaluated once at package initialization
// instead of at every use. Interface types are obtained through a nil pointer
// to the interface followed by Elem.
var tBool = reflect.TypeOf(false)
var tFloat64 = reflect.TypeOf(float64(0))
var tInt32 = reflect.TypeOf(int32(0))
var tInt64 = reflect.TypeOf(int64(0))
var tString = reflect.TypeOf("")
var tTime = reflect.TypeOf(time.Time{})
var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem()
var tByteSlice = reflect.TypeOf([]byte(nil))
var tByte = reflect.TypeOf(byte(0x00))
var tURL = reflect.TypeOf(url.URL{})
var tJSONNumber = reflect.TypeOf(json.Number(""))
var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem()
var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem()
var tBinary = reflect.TypeOf(Binary{})
var tUndefined = reflect.TypeOf(Undefined{})
var tOID = reflect.TypeOf(ObjectID{})
var tDateTime = reflect.TypeOf(DateTime(0))
var tNull = reflect.TypeOf(Null{})
var tRegex = reflect.TypeOf(Regex{})
var tCodeWithScope = reflect.TypeOf(CodeWithScope{})
var tDBPointer = reflect.TypeOf(DBPointer{})
var tJavaScript = reflect.TypeOf(JavaScript(""))
var tSymbol = reflect.TypeOf(Symbol(""))
var tTimestamp = reflect.TypeOf(Timestamp{})
var tDecimal = reflect.TypeOf(Decimal128{})
var tVector = reflect.TypeOf(Vector{})
var tMinKey = reflect.TypeOf(MinKey{})
var tMaxKey = reflect.TypeOf(MaxKey{})
var tD = reflect.TypeOf(D{})
var tA = reflect.TypeOf(A{})
var tE = reflect.TypeOf(E{})
var tCoreDocument = reflect.TypeOf(bsoncore.Document{})
var tCoreArray = reflect.TypeOf(bsoncore.Array{})
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"math"
"reflect"
)
// uintCodec is the Codec used for uint values.
type uintCodec struct {
	// encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the
	// minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value.
	// The EncodeContext's minSize setting has the same effect and also applies to uint64.
	encodeToMinSize bool
}

// Assert that uintCodec satisfies the typeDecoder interface, which allows it to be used
// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection.
var _ typeDecoder = &uintCodec{}
// EncodeValue is the ValueEncoder for uint types.
//
// uint8 and uint16 are always written as int32. uint, uint32 and uint64 are
// written as int32 when min-size encoding applies and the value fits, and as
// int64 otherwise; a value larger than math.MaxInt64 is an error. Any other
// kind yields a ValueEncoderError.
func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	kind := val.Kind()
	switch kind {
	case reflect.Uint8, reflect.Uint16:
		return vw.WriteInt32(int32(val.Uint()))
	case reflect.Uint, reflect.Uint32, reflect.Uint64:
		// Wide kinds are handled below.
	default:
		return ValueEncoderError{
			Name:     "UintEncodeValue",
			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
			Received: val,
		}
	}

	u := val.Uint()
	// ec.minSize applies to every wide kind; the codec-level option excludes uint64.
	minSize := ec.minSize || (uic.encodeToMinSize && kind != reflect.Uint64)
	switch {
	case minSize && u <= math.MaxInt32:
		return vw.WriteInt32(int32(u))
	case u > math.MaxInt64:
		return fmt.Errorf("%d overflows int64", u)
	default:
		return vw.WriteInt64(int64(u))
	}
}
// decodeType reads the next value from vr and converts it to the unsigned
// integer type t, returned as a reflect.Value holding a value of that kind.
//
// BSON int32, int64, double and boolean values are accepted (true decodes as
// 1, false as 0); null and undefined decode as 0. A double with a fractional
// part is rejected with errCannotTruncate unless dc.truncate is set, and a
// double of 2^63 or more is rejected as overflowing int64. Values that are
// negative or too large for t produce an overflow error. Any other BSON type,
// or a t that is not an unsigned integer kind, is an error.
func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	var i64 int64
	var err error
	switch vrType := vr.Type(); vrType {
	case TypeInt32:
		i32, err := vr.ReadInt32()
		if err != nil {
			return emptyValue, err
		}
		i64 = int64(i32)
	case TypeInt64:
		i64, err = vr.ReadInt64()
		if err != nil {
			return emptyValue, err
		}
	case TypeDouble:
		f64, err := vr.ReadDouble()
		if err != nil {
			return emptyValue, err
		}
		if !dc.truncate && math.Floor(f64) != f64 {
			return emptyValue, errCannotTruncate
		}
		// float64(math.MaxInt64) rounds up to 2^63, which int64 cannot hold.
		// Converting an out-of-range float to an integer gives an
		// implementation-dependent result, so the bound itself is rejected too.
		if f64 >= float64(math.MaxInt64) {
			return emptyValue, fmt.Errorf("%g overflows int64", f64)
		}
		i64 = int64(f64)
	case TypeBoolean:
		b, err := vr.ReadBoolean()
		if err != nil {
			return emptyValue, err
		}
		if b {
			i64 = 1
		}
	case TypeNull:
		if err = vr.ReadNull(); err != nil {
			return emptyValue, err
		}
	case TypeUndefined:
		if err = vr.ReadUndefined(); err != nil {
			return emptyValue, err
		}
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
	}

	// Narrow the intermediate int64 to the requested kind, rejecting anything
	// that is negative or does not fit.
	switch t.Kind() {
	case reflect.Uint8:
		if i64 < 0 || i64 > math.MaxUint8 {
			return emptyValue, fmt.Errorf("%d overflows uint8", i64)
		}
		return reflect.ValueOf(uint8(i64)), nil
	case reflect.Uint16:
		if i64 < 0 || i64 > math.MaxUint16 {
			return emptyValue, fmt.Errorf("%d overflows uint16", i64)
		}
		return reflect.ValueOf(uint16(i64)), nil
	case reflect.Uint32:
		if i64 < 0 || i64 > math.MaxUint32 {
			return emptyValue, fmt.Errorf("%d overflows uint32", i64)
		}
		return reflect.ValueOf(uint32(i64)), nil
	case reflect.Uint64:
		if i64 < 0 {
			return emptyValue, fmt.Errorf("%d overflows uint64", i64)
		}
		return reflect.ValueOf(uint64(i64)), nil
	case reflect.Uint:
		if i64 < 0 {
			return emptyValue, fmt.Errorf("%d overflows uint", i64)
		}
		v := uint64(i64)
		if v > math.MaxUint { // Can we fit this inside of an uint
			return emptyValue, fmt.Errorf("%d overflows uint", i64)
		}
		return reflect.ValueOf(uint(v)), nil
	default:
		return emptyValue, ValueDecoderError{
			Name:     "UintDecodeValue",
			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
			Received: reflect.Zero(t),
		}
	}
}
// DecodeValue is the ValueDecoder for uint types. val must be settable; the
// decoded value is stored with SetUint only when decoding succeeds.
func (uic *uintCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if settable := val.CanSet(); !settable {
		return ValueDecoderError{
			Name:     "UintDecodeValue",
			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
			Received: val,
		}
	}

	decoded, err := uic.decodeType(dc, vr, val.Type())
	if err == nil {
		val.SetUint(decoded.Uint())
	}
	return err
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"fmt"
)
// Unmarshaler is the interface implemented by types that can unmarshal a BSON
// document representation of themselves. The input can be assumed to be a valid
// encoding of a BSON document. UnmarshalBSON must copy the BSON data if it
// wishes to retain the data after returning.
//
// Unmarshaler is only used to unmarshal full BSON documents. To create custom
// BSON unmarshaling behavior for individual values in a BSON document,
// implement the ValueUnmarshaler interface instead.
type Unmarshaler interface {
	UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is the interface implemented by types that can unmarshal a
// BSON value representation of themselves. The input can be assumed to be a
// valid encoding of a BSON value. UnmarshalBSONValue must copy the BSON value
// bytes if it wishes to retain the data after returning.
//
// UnmarshalBSONValue receives the value's type as typ and the value's bytes as
// data.
//
// ValueUnmarshaler is only used to unmarshal individual values in a BSON
// document. To create custom BSON unmarshaling behavior for an entire BSON
// document, implement the Unmarshaler interface instead.
type ValueUnmarshaler interface {
	UnmarshalBSONValue(typ byte, data []byte) error
}
// Unmarshal parses the BSON-encoded data and stores the result in the value
// pointed to by val. If val is nil or not a pointer, Unmarshal returns an
// error.
//
// The length prefix of the document must equal len(data); otherwise an
// "invalid document length" error is returned before any decoding happens.
//
// When unmarshaling BSON, if the BSON value is null and the Go value is a
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
func Unmarshal(data []byte, val interface{}) error {
	vr := newDocumentReader(bytes.NewReader(data))

	declared, err := vr.peekLength()
	if err != nil {
		return err
	}
	if int(declared) != len(data) {
		return fmt.Errorf("invalid document length")
	}

	dc := DecodeContext{Registry: defaultRegistry}
	return unmarshalFromReader(dc, vr, val)
}
// UnmarshalValue parses the BSON value of type t using the package's default
// registry and stores the result in the value pointed to by val. If val is nil
// or not a pointer, UnmarshalValue returns an error.
func UnmarshalValue(t Type, data []byte, val interface{}) error {
	reader := newValueReader(t, bytes.NewReader(data))
	dc := DecodeContext{Registry: defaultRegistry}
	return unmarshalFromReader(dc, reader, val)
}
// UnmarshalExtJSON parses the extended JSON-encoded data and stores the result
// in the value pointed to by val. If val is nil or not a pointer, UnmarshalExtJSON
// returns an error.
//
// If canonicalOnly is true, UnmarshalExtJSON returns an error if the Extended
// JSON was not marshaled in canonical mode.
//
// For more information about Extended JSON, see
// https://www.mongodb.com/docs/manual/reference/mongodb-extended-json/
func UnmarshalExtJSON(data []byte, canonicalOnly bool, val interface{}) error {
	reader, err := NewExtJSONValueReader(bytes.NewReader(data), canonicalOnly)
	if err != nil {
		return err
	}

	dc := DecodeContext{Registry: defaultRegistry}
	return unmarshalFromReader(dc, reader, val)
}
// unmarshalFromReader decodes the value provided by vr into val using a Decoder
// borrowed from decPool. The Decoder is handed back to the pool when the
// function returns.
func unmarshalFromReader(dc DecodeContext, vr ValueReader, val interface{}) error {
	dec := decPool.Get().(*Decoder)
	defer decPool.Put(dec)
	// Point the pooled Decoder at vr first, then install the caller's context.
	// NOTE(review): presumably Reset would otherwise leave the context of a
	// previous use in place — confirm against Decoder.Reset.
	dec.Reset(vr)
	dec.dc = dc
	return dec.Decode(val)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
)
// Assert at compile time that *valueReader satisfies the ValueReader interface.
var _ ValueReader = &valueReader{}

// ErrEOA is the error returned when the end of a BSON array has been reached.
var ErrEOA = errors.New("end of array")

// ErrEOD is the error returned when the end of a BSON document has been reached.
var ErrEOD = errors.New("end of document")
// vrState is a single frame of the valueReader's stack.
type vrState struct {
	// mode says what the reader is positioned on in this frame.
	mode mode
	// vType is the BSON type of the value; it is set for element and value frames.
	vType Type
	// end is the reader offset at which the frame's document, array or
	// code-with-scope ends. It is left at zero for element and value frames.
	end int64
}
// valueReader is for reading BSON values.
type valueReader struct {
	// r is the buffered source of BSON bytes.
	r *bufio.Reader
	// offset counts the bytes consumed through the read helpers. pop resets it
	// to zero whenever the reader is back at frame 0.
	offset int64
	// stack holds one vrState per nesting level; frame indexes the current one.
	stack []vrState
	frame int64
}
// NewDocumentReader returns a ValueReader using r for the underlying BSON
// representation. The returned reader starts in top-level document mode.
func NewDocumentReader(r io.Reader) ValueReader {
	return newDocumentReader(r)
}
// newValueReader returns a ValueReader that starts in the Value mode instead of in top
// level document mode. This enables the creation of a ValueReader for a single BSON value
// of type t read from r.
func newValueReader(t Type, r io.Reader) ValueReader {
	reader := newDocumentReader(r)
	if len(reader.stack) == 0 {
		reader.stack = make([]vrState, 1, 5)
	}

	bottom := &reader.stack[0]
	bottom.mode = mValue
	bottom.vType = t
	return reader
}
// newDocumentReader wraps r in a bufio.Reader and returns a valueReader whose
// stack holds a single frame in top-level mode.
func newDocumentReader(r io.Reader) *valueReader {
	reader := &valueReader{
		r:     bufio.NewReader(r),
		stack: make([]vrState, 1, 5),
	}
	reader.stack[0].mode = mTopLevel
	return reader
}
// advanceFrame moves the reader to the next stack frame, growing the stack when
// that frame does not exist yet, and resets the frame to its zero state.
func (vr *valueReader) advanceFrame() {
	if vr.frame+1 >= int64(len(vr.stack)) {
		// The next frame is past the end of the stack: extend it by one entry.
		used := len(vr.stack)
		if used+1 >= cap(vr.stack) {
			// Out of capacity; move to a backing array of about twice the size.
			grown := make([]vrState, 2*cap(vr.stack)+1)
			copy(grown, vr.stack)
			vr.stack = grown
		}
		vr.stack = vr.stack[:used+1]
	}
	vr.frame++

	// Clear anything left behind by an earlier use of this frame.
	vr.stack[vr.frame] = vrState{}
}
// pushDocument enters a new document frame, consumes the document's length
// prefix and records the offset at which the document ends.
func (vr *valueReader) pushDocument() error {
	vr.advanceFrame()
	top := &vr.stack[vr.frame]
	top.mode = mDocument

	size, err := vr.readLength()
	if err != nil {
		return err
	}
	// The four prefix bytes are already part of vr.offset, hence the -4.
	top.end = int64(size) + vr.offset - 4
	return nil
}
// pushArray enters a new array frame, consumes the array's length prefix and
// records the offset at which the array ends.
func (vr *valueReader) pushArray() error {
	vr.advanceFrame()
	top := &vr.stack[vr.frame]
	top.mode = mArray

	size, err := vr.readLength()
	if err != nil {
		return err
	}
	// The four prefix bytes are already part of vr.offset, hence the -4.
	top.end = int64(size) + vr.offset - 4
	return nil
}
// pushElement enters a new frame positioned on a document element of type t.
func (vr *valueReader) pushElement(t Type) {
	vr.advanceFrame()
	vr.stack[vr.frame] = vrState{mode: mElement, vType: t}
}
// pushValue enters a new frame positioned on an array value of type t.
func (vr *valueReader) pushValue(t Type) {
	vr.advanceFrame()
	vr.stack[vr.frame] = vrState{mode: mValue, vType: t}
}
// pushCodeWithScope enters a new code-with-scope frame, consumes the length
// prefix of the scope document and records the offset at which it ends. The
// length that was read is returned.
func (vr *valueReader) pushCodeWithScope() (int64, error) {
	vr.advanceFrame()
	top := &vr.stack[vr.frame]
	top.mode = mCodeWithScope

	size, err := vr.readLength()
	if err != nil {
		return 0, err
	}
	// The four prefix bytes are already part of vr.offset, hence the -4.
	top.end = int64(size) + vr.offset - 4
	return int64(size), nil
}
// pop leaves the current frame. Element and value frames are popped once;
// document, array and code-with-scope frames are popped twice so that the
// element or value frame that introduced them is left as well. Frame 0 is never
// popped.
//
// Before a frame is left, any bytes between the current offset and the frame's
// end are discarded from the underlying reader; an error from that discard is
// returned.
func (vr *valueReader) pop() error {
	var cnt int
	switch vr.stack[vr.frame].mode {
	case mElement, mValue:
		cnt = 1
	case mDocument, mArray, mCodeWithScope:
		cnt = 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
	}
	for i := 0; i < cnt && vr.frame > 0; i++ {
		if vr.offset < vr.stack[vr.frame].end {
			// NOTE(review): the discarded bytes are not added to vr.offset —
			// confirm this branch is only reached when that does not matter.
			_, err := vr.r.Discard(int(vr.stack[vr.frame].end - vr.offset))
			if err != nil {
				return err
			}
		}
		vr.frame--
	}
	if vr.frame == 0 {
		// Back at the bottom frame: rebase its end so it stays relative to the
		// offset, which restarts at zero.
		if vr.stack[0].end > vr.offset {
			vr.stack[0].end -= vr.offset
		} else {
			vr.stack[0].end = 0
		}
		vr.offset = 0
	}
	return nil
}
// invalidTransitionErr builds the TransitionError reported when the method
// called name cannot run in the reader's current mode. modes lists the modes
// in which the call would have been valid. The parent mode is filled in only
// when the reader is not at frame 0.
func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
	transition := TransitionError{
		action:      "read",
		name:        name,
		modes:       modes,
		destination: destination,
		current:     vr.stack[vr.frame].mode,
	}
	if vr.frame != 0 {
		transition.parent = vr.stack[vr.frame-1].mode
	}
	return transition
}
// typeError reports an attempt to read a value of type t while the reader is
// positioned on a value of a different type.
func (vr *valueReader) typeError(t Type) error {
	positioned := vr.stack[vr.frame].vType
	return fmt.Errorf("positioned on %s, but attempted to read %s", positioned, t)
}
// invalidDocumentLengthError reports a terminating null byte found at an offset
// other than the end recorded for the current frame.
func (vr *valueReader) invalidDocumentLengthError() error {
	expectedEnd := vr.stack[vr.frame].end
	return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", expectedEnd, vr.offset)
}
// ensureElementValue checks that the reader is positioned on an element or
// value of type t. It returns a transition error when the current mode is
// neither mElement nor mValue, and a type error when the type differs.
func (vr *valueReader) ensureElementValue(t Type, destination mode, callerName string) error {
	current := vr.stack[vr.frame]
	if current.mode != mElement && current.mode != mValue {
		return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue})
	}
	if current.vType != t {
		return vr.typeError(t)
	}
	return nil
}
// Type returns the BSON type recorded in the reader's current frame.
func (vr *valueReader) Type() Type {
	return vr.stack[vr.frame].vType
}
// appendNextElement appends the raw bytes of the value the reader is positioned
// on to dst and advances the reader's offset past them. It does not pop the
// current frame.
//
// For length-prefixed values the int32 prefix is only peeked, so the prefix
// bytes are part of what is copied. A negative prefix, or a total size that no
// longer fits in an int32, is reported as an error rather than being passed to
// make, which would panic. An unknown BSON type is an error as well.
func (vr *valueReader) appendNextElement(dst []byte) ([]byte, error) {
	var (
		// size is the number of bytes to copy. For prefixed values it starts
		// out as the bytes the prefix does not account for.
		size int64
		// prefixed marks values that begin with an int32 length prefix.
		prefixed bool
	)
	switch vType := vr.stack[vr.frame].vType; vType {
	case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
		prefixed = true
	case TypeBinary:
		prefixed, size = true, 4+1 // binary length + subtype byte
	case TypeBoolean:
		size = 1
	case TypeDBPointer:
		prefixed, size = true, 4+12 // string length + ObjectID length
	case TypeDateTime, TypeDouble, TypeInt64, TypeTimestamp:
		size = 8
	case TypeDecimal128:
		size = 16
	case TypeInt32:
		size = 4
	case TypeJavaScript, TypeString, TypeSymbol:
		prefixed, size = true, 4 // string length
	case TypeMaxKey, TypeMinKey, TypeNull, TypeUndefined:
		size = 0
	case TypeObjectID:
		size = 12
	case TypeRegex:
		for n := 0; n < 2; n++ { // Read two C strings.
			str, err := vr.r.ReadBytes(0x00)
			if err != nil {
				return nil, err
			}
			dst = append(dst, str...)
			vr.offset += int64(len(str))
		}
		return dst, nil
	default:
		return nil, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vType)
	}

	if prefixed {
		length, err := vr.peekLength()
		if err != nil {
			return nil, err
		}
		// peekLength does not validate the prefix, so malformed input can
		// carry a negative length.
		if length < 0 {
			return nil, fmt.Errorf("invalid negative length: %d", length)
		}
		size += int64(length)
	}
	if size > math.MaxInt32 {
		return nil, fmt.Errorf("value length %d overflows int32", size)
	}

	buf := make([]byte, size)
	if _, err := io.ReadFull(vr.r, buf); err != nil {
		return nil, err
	}
	vr.offset += int64(len(buf))
	return append(dst, buf...), nil
}
// readValueBytes appends the raw bytes of the current value to dst.
//
// In top-level mode the whole document, sized by its peeked length prefix, is
// appended and the returned Type is zero. In element or value mode the bytes of
// that single value are appended, the frame is popped, and the value's type is
// returned. Any other mode is a transition error.
func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) {
	switch vr.stack[vr.frame].mode {
	case mTopLevel:
		size, err := vr.peekLength()
		if err != nil {
			return Type(0), nil, err
		}
		out, err := vr.appendBytes(dst, size)
		if err != nil {
			return Type(0), nil, err
		}
		return Type(0), out, nil
	case mElement, mValue:
		out, err := vr.appendNextElement(dst)
		if err != nil {
			return Type(0), out, err
		}
		// Capture the type before pop moves the reader off this frame.
		vType := vr.stack[vr.frame].vType
		if err := vr.pop(); err != nil {
			return Type(0), nil, err
		}
		return vType, out, nil
	default:
		return Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue})
	}
}
// Skip consumes the value the reader is positioned on without decoding it and
// pops the frame. It is only valid in element or value mode.
func (vr *valueReader) Skip() error {
	if m := vr.stack[vr.frame].mode; m != mElement && m != mValue {
		return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
	}

	if _, err := vr.appendNextElement(nil); err != nil {
		return err
	}
	return vr.pop()
}
// ReadArray enters the array the reader is positioned on and returns the
// reader itself as the ArrayReader.
func (vr *valueReader) ReadArray() (ArrayReader, error) {
	err := vr.ensureElementValue(TypeArray, mArray, "ReadArray")
	if err == nil {
		err = vr.pushArray()
	}
	if err != nil {
		return nil, err
	}
	return vr, nil
}
// ReadBinary reads a BSON binary value and returns its payload and subtype
// byte. For subtype 0x02 with an outer length greater than 4, a second, inner
// length prefix is read and used as the payload length.
func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
	if err = vr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
		return nil, 0, err
	}

	size, err := vr.readLength()
	if err != nil {
		return nil, 0, err
	}
	if btype, err = vr.readByte(); err != nil {
		return nil, 0, err
	}

	// Check length in case it is an old binary without a length.
	if btype == 0x02 && size > 4 {
		if size, err = vr.readLength(); err != nil {
			return nil, 0, err
		}
	}

	b = make([]byte, size)
	if err = vr.read(b); err != nil {
		return nil, 0, err
	}
	if err = vr.pop(); err != nil {
		return nil, 0, err
	}
	return b, btype, nil
}
// ReadBoolean reads a BSON boolean. Only the bytes 0 and 1 are accepted; any
// other byte is an error.
func (vr *valueReader) ReadBoolean() (bool, error) {
	if err := vr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
		return false, err
	}

	raw, err := vr.readByte()
	switch {
	case err != nil:
		return false, err
	case raw > 1:
		return false, fmt.Errorf("invalid byte for boolean, %b", raw)
	}

	if err := vr.pop(); err != nil {
		return false, err
	}
	return raw == 1, nil
}
// ReadDocument enters the document the reader is positioned on and returns the
// reader itself as the DocumentReader.
//
// In top-level mode the length prefix is consumed in place and no frame is
// pushed; a length of 4 or less is rejected as an invalid document length. In
// element or value mode the value must be an embedded document, for which a
// new document frame is pushed. Any other mode is a transition error.
func (vr *valueReader) ReadDocument() (DocumentReader, error) {
	switch vr.stack[vr.frame].mode {
	case mTopLevel:
		length, err := vr.readLength()
		if err != nil {
			return nil, err
		}
		if length <= 4 {
			// This is the document's length prefix, not a string's.
			return nil, fmt.Errorf("invalid document length: %d", length)
		}
		// The four prefix bytes are already part of vr.offset, hence the -4.
		vr.stack[vr.frame].end = int64(length) + vr.offset - 4
		return vr, nil
	case mElement, mValue:
		if vr.stack[vr.frame].vType != TypeEmbeddedDocument {
			return nil, vr.typeError(TypeEmbeddedDocument)
		}
	default:
		return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
	}

	if err := vr.pushDocument(); err != nil {
		return nil, err
	}
	return vr, nil
}
// ReadCodeWithScope reads the code string of a code-with-scope value and
// returns the reader itself, positioned on the scope document. The declared
// total length must match the sum of its components, otherwise an error is
// returned.
func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
	if err = vr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
		return "", nil, err
	}

	declared, err := vr.readLength()
	if err != nil {
		return "", nil, err
	}
	codeLen, err := vr.readLength()
	if err != nil {
		return "", nil, err
	}
	if codeLen <= 0 {
		return "", nil, fmt.Errorf("invalid string length: %d", codeLen)
	}

	raw := make([]byte, codeLen)
	if err = vr.read(raw); err != nil {
		return "", nil, err
	}
	// The last byte of the string is dropped from the returned code.
	code = string(raw[:len(raw)-1])

	scopeLen, err := vr.pushCodeWithScope()
	if err != nil {
		return "", nil, err
	}

	// The total length should equal:
	// 4 (total length) + strLength + 4 (the length of str itself) + (document length)
	if sum := int64(4+codeLen+4) + scopeLen; int64(declared) != sum {
		return "", nil, fmt.Errorf(
			"length of CodeWithScope does not match lengths of components; total: %d; components: %d",
			declared, sum,
		)
	}
	return code, vr, nil
}
// ReadDBPointer reads a BSON DBPointer: a length-prefixed namespace string
// followed by an ObjectID.
func (vr *valueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
	if err = vr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
		return "", oid, err
	}
	if ns, err = vr.readString(); err != nil {
		return "", oid, err
	}
	if err = vr.read(oid[:]); err != nil {
		return "", ObjectID{}, err
	}
	if err = vr.pop(); err != nil {
		return "", ObjectID{}, err
	}
	return ns, oid, nil
}
// ReadDateTime reads a BSON datetime and returns its raw int64 value.
func (vr *valueReader) ReadDateTime() (int64, error) {
	err := vr.ensureElementValue(TypeDateTime, 0, "ReadDateTime")
	if err != nil {
		return 0, err
	}

	value, err := vr.readi64()
	if err != nil {
		return 0, err
	}
	if err = vr.pop(); err != nil {
		return 0, err
	}
	return value, nil
}
// ReadDecimal128 reads a 16-byte BSON decimal128. The first eight bytes are the
// little-endian low word and the last eight the high word.
func (vr *valueReader) ReadDecimal128() (Decimal128, error) {
	if err := vr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
		return Decimal128{}, err
	}

	var raw [16]byte
	if err := vr.read(raw[:]); err != nil {
		return Decimal128{}, err
	}
	low := binary.LittleEndian.Uint64(raw[:8])
	high := binary.LittleEndian.Uint64(raw[8:])

	if err := vr.pop(); err != nil {
		return Decimal128{}, err
	}
	return NewDecimal128(high, low), nil
}
// ReadDouble reads a BSON double, interpreting the eight little-endian bytes as
// IEEE 754 bits.
func (vr *valueReader) ReadDouble() (float64, error) {
	err := vr.ensureElementValue(TypeDouble, 0, "ReadDouble")
	if err != nil {
		return 0, err
	}

	bits, err := vr.readu64()
	if err != nil {
		return 0, err
	}
	if err = vr.pop(); err != nil {
		return 0, err
	}
	return math.Float64frombits(bits), nil
}
// ReadInt32 reads a BSON int32. The frame is popped before the four value bytes
// are consumed.
func (vr *valueReader) ReadInt32() (int32, error) {
	err := vr.ensureElementValue(TypeInt32, 0, "ReadInt32")
	if err == nil {
		err = vr.pop()
	}
	if err != nil {
		return 0, err
	}
	return vr.readi32()
}
// ReadInt64 reads a BSON int64. The frame is popped before the eight value
// bytes are consumed.
func (vr *valueReader) ReadInt64() (int64, error) {
	err := vr.ensureElementValue(TypeInt64, 0, "ReadInt64")
	if err == nil {
		err = vr.pop()
	}
	if err != nil {
		return 0, err
	}
	return vr.readi64()
}
// ReadJavascript reads a BSON JavaScript value as a length-prefixed string. The
// frame is popped before the string is consumed.
func (vr *valueReader) ReadJavascript() (string, error) {
	err := vr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript")
	if err == nil {
		err = vr.pop()
	}
	if err != nil {
		return "", err
	}
	return vr.readString()
}
// ReadMaxKey consumes a BSON max key, which carries no payload, by popping the
// frame.
func (vr *valueReader) ReadMaxKey() error {
	err := vr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey")
	if err == nil {
		err = vr.pop()
	}
	return err
}
// ReadMinKey consumes a BSON min key, which carries no payload, by popping the
// frame.
func (vr *valueReader) ReadMinKey() error {
	err := vr.ensureElementValue(TypeMinKey, 0, "ReadMinKey")
	if err == nil {
		err = vr.pop()
	}
	return err
}
// ReadNull consumes a BSON null, which carries no payload, by popping the
// frame.
func (vr *valueReader) ReadNull() error {
	err := vr.ensureElementValue(TypeNull, 0, "ReadNull")
	if err == nil {
		err = vr.pop()
	}
	return err
}
// ReadObjectID reads a BSON ObjectID. A zero ObjectID is returned alongside any
// error.
func (vr *valueReader) ReadObjectID() (ObjectID, error) {
	if err := vr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
		return ObjectID{}, err
	}

	var id ObjectID
	if err := vr.read(id[:]); err != nil {
		return ObjectID{}, err
	}
	if err := vr.pop(); err != nil {
		return ObjectID{}, err
	}
	return id, nil
}
// ReadRegex reads a BSON regular expression as two consecutive C strings and
// returns them as pattern and options, in that order.
func (vr *valueReader) ReadRegex() (string, string, error) {
	err := vr.ensureElementValue(TypeRegex, 0, "ReadRegex")
	if err != nil {
		return "", "", err
	}

	pattern, err := vr.readCString()
	if err != nil {
		return "", "", err
	}
	options, err := vr.readCString()
	if err != nil {
		return "", "", err
	}

	if err = vr.pop(); err != nil {
		return "", "", err
	}
	return pattern, options, nil
}
// ReadString reads a BSON string. The frame is popped before the
// length-prefixed string is consumed.
func (vr *valueReader) ReadString() (string, error) {
	err := vr.ensureElementValue(TypeString, 0, "ReadString")
	if err == nil {
		err = vr.pop()
	}
	if err != nil {
		return "", err
	}
	return vr.readString()
}
// ReadSymbol reads a BSON symbol as a length-prefixed string. The frame is
// popped before the string is consumed.
func (vr *valueReader) ReadSymbol() (string, error) {
	err := vr.ensureElementValue(TypeSymbol, 0, "ReadSymbol")
	if err == nil {
		err = vr.pop()
	}
	if err != nil {
		return "", err
	}
	return vr.readString()
}
// ReadTimestamp reads a BSON timestamp. The first four bytes on the wire are
// returned as i and the next four as t.
func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) {
	if err = vr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
		return 0, 0, err
	}
	if i, err = vr.readu32(); err != nil {
		return 0, 0, err
	}
	if t, err = vr.readu32(); err != nil {
		return 0, 0, err
	}
	if err = vr.pop(); err != nil {
		return 0, 0, err
	}
	return t, i, nil
}
// ReadUndefined consumes a BSON undefined, which carries no payload, by popping
// the frame.
func (vr *valueReader) ReadUndefined() error {
	err := vr.ensureElementValue(TypeUndefined, 0, "ReadUndefined")
	if err == nil {
		err = vr.pop()
	}
	return err
}
// ReadElement reads the next element header of the current document and
// returns the element's key together with the reader, now positioned on the
// element's value.
//
// A zero type byte marks the end of the document: the frame is popped and
// ErrEOD is returned, unless the byte was found at an offset other than the
// frame's recorded end, which is reported as an invalid document length.
func (vr *valueReader) ReadElement() (string, ValueReader, error) {
	if m := vr.stack[vr.frame].mode; m != mTopLevel && m != mDocument && m != mCodeWithScope {
		return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
	}

	typeByte, err := vr.readByte()
	if err != nil {
		return "", nil, err
	}

	if typeByte != 0 {
		key, err := vr.readCString()
		if err != nil {
			return "", nil, err
		}
		vr.pushElement(Type(typeByte))
		return key, vr, nil
	}

	// Terminating null byte of the document.
	if vr.offset != vr.stack[vr.frame].end {
		return "", nil, vr.invalidDocumentLengthError()
	}
	_ = vr.pop() // Ignore the error because the call here never reads from the underlying reader.
	return "", nil, ErrEOD
}
// ReadValue reads the next value header of the current array and returns the
// reader, now positioned on that value. The array key is read and discarded.
//
// A zero type byte marks the end of the array: the frame is popped and ErrEOA
// is returned, unless the byte was found at an offset other than the frame's
// recorded end, which is reported as an invalid document length.
func (vr *valueReader) ReadValue() (ValueReader, error) {
	if vr.stack[vr.frame].mode != mArray {
		return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
	}

	typeByte, err := vr.readByte()
	if err != nil {
		return nil, err
	}

	if typeByte != 0 {
		if _, err := vr.readCString(); err != nil {
			return nil, err
		}
		vr.pushValue(Type(typeByte))
		return vr, nil
	}

	// Terminating null byte of the array.
	if vr.offset != vr.stack[vr.frame].end {
		return nil, vr.invalidDocumentLengthError()
	}
	_ = vr.pop() // Ignore the error because the call here never reads from the underlying reader.
	return nil, ErrEOA
}
// read fills p completely from the underlying reader and advances the offset
// by the number of bytes read. On error the offset is left unchanged.
func (vr *valueReader) read(p []byte) error {
	n, err := io.ReadFull(vr.r, p)
	if err == nil {
		vr.offset += int64(n)
	}
	return err
}
// appendBytes reads exactly length bytes from the underlying reader and appends
// them to dst. A negative length is rejected with an error: callers may pass a
// peeked, unvalidated length prefix, and make would panic on it.
func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) {
	if length < 0 {
		return nil, fmt.Errorf("invalid negative length: %d", length)
	}

	buf := make([]byte, length)
	if err := vr.read(buf); err != nil {
		return nil, err
	}
	return append(dst, buf...), nil
}
// readByte reads a single byte and advances the offset by one.
func (vr *valueReader) readByte() (byte, error) {
	value, err := vr.r.ReadByte()
	if err != nil {
		return 0x0, err
	}
	vr.offset++
	return value, nil
}
// readCString reads up to and including the next null byte, advances the
// offset by the number of bytes consumed, and returns the string without the
// terminator.
func (vr *valueReader) readCString() (string, error) {
	raw, err := vr.r.ReadString(0x00)
	if err != nil {
		return "", err
	}
	consumed := len(raw)
	vr.offset += int64(consumed)
	return raw[:consumed-1], nil
}
// readString reads a length-prefixed BSON string. The length must be positive
// and the final byte must be a null terminator, which is not part of the
// returned string.
func (vr *valueReader) readString() (string, error) {
	size, err := vr.readLength()
	switch {
	case err != nil:
		return "", err
	case size <= 0:
		return "", fmt.Errorf("invalid string length: %d", size)
	}

	raw := make([]byte, size)
	if err := vr.read(raw); err != nil {
		return "", err
	}

	if last := raw[size-1]; last != 0x00 {
		return "", fmt.Errorf("string does not end with null byte, but with %v", last)
	}
	return string(raw[:size-1]), nil
}
// peekLength returns the next four bytes decoded as a little-endian int32
// without consuming them. The value is not validated and may be negative.
func (vr *valueReader) peekLength() (int32, error) {
	head, err := vr.r.Peek(4)
	if err != nil {
		return 0, err
	}
	length := binary.LittleEndian.Uint32(head)
	return int32(length), nil
}
// readLength consumes a little-endian int32 length prefix and rejects negative
// values.
func (vr *valueReader) readLength() (int32, error) {
	length, err := vr.readi32()
	switch {
	case err != nil:
		return 0, err
	case length < 0:
		return 0, fmt.Errorf("invalid negative length: %d", length)
	default:
		return length, nil
	}
}
// readi32 consumes four bytes and decodes them as a little-endian int32.
func (vr *valueReader) readi32() (int32, error) {
	var raw [4]byte
	if err := vr.read(raw[:]); err != nil {
		return 0, err
	}
	bits := binary.LittleEndian.Uint32(raw[:])
	return int32(bits), nil
}
// readu32 consumes four bytes and decodes them as a little-endian uint32.
func (vr *valueReader) readu32() (uint32, error) {
	var raw [4]byte
	if err := vr.read(raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint32(raw[:]), nil
}
// readi64 consumes eight bytes and decodes them as a little-endian int64.
func (vr *valueReader) readi64() (int64, error) {
	var raw [8]byte
	if err := vr.read(raw[:]); err != nil {
		return 0, err
	}
	bits := binary.LittleEndian.Uint64(raw[:])
	return int64(bits), nil
}
// readu64 consumes eight bytes and decodes them as a little-endian uint64.
func (vr *valueReader) readu64() (uint64, error) {
	var raw [8]byte
	if err := vr.read(raw[:]); err != nil {
		return 0, err
	}

	value := binary.LittleEndian.Uint64(raw[:])
	return value, nil
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"sync"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
)
// Compile-time assertion that *valueWriter implements ValueWriter.
var _ ValueWriter = (*valueWriter)(nil)

// vwPool holds reusable valueWriter instances. When the pool is empty, New
// hands out a zero-valued writer.
var vwPool = sync.Pool{
	New: func() interface{} { return &valueWriter{} },
}
// putValueWriter returns vw to vwPool after clearing its io.Writer so the
// pooled value does not keep the destination alive. A nil vw is ignored.
func putValueWriter(vw *valueWriter) {
	if vw == nil {
		return
	}

	vw.w = nil // don't leak the writer
	vwPool.Put(vw)
}
// maxSize is the largest buffer length that writeLength accepts; a longer
// buffer makes writeLength return errMaxDocumentSizeExceeded.
//
// It is a variable rather than a constant so that tests can lower it and
// exercise the limit without allocating a multi-gigabyte slice.
var maxSize = math.MaxInt32
// errMaxDocumentSizeExceeded is returned when the buffered output grows past
// maxSize. size records the offending length in bytes.
type errMaxDocumentSizeExceeded struct {
	size int64
}

// Error implements the error interface.
func (mdse errMaxDocumentSizeExceeded) Error() string {
	const format = "document size (%d) is larger than the max int32"
	return fmt.Sprintf(format, mdse.size)
}
// vwMode enumerates value-writer modes. The zero value is deliberately left
// unnamed so that it never matches a real mode.
type vwMode int

const (
	_ vwMode = iota
	vwTopLevel
	vwDocument
	vwArray
	vwValue
	vwElement
	vwCodeWithScope
)

// String returns the human-readable name of the mode, or "UnknownMode" for
// any value outside the declared constants.
func (vm vwMode) String() string {
	if vm < vwTopLevel || vm > vwCodeWithScope {
		return "UnknownMode"
	}

	names := [...]string{
		vwTopLevel:      "TopLevel",
		vwDocument:      "DocumentMode",
		vwArray:         "ArrayMode",
		vwValue:         "ValueMode",
		vwElement:       "ElementMode",
		vwCodeWithScope: "CodeWithScopeMode",
	}
	return names[vm]
}
// vwState is a single frame on the valueWriter stack.
type vwState struct {
// mode is the writer mode this frame represents.
mode mode
// key is the document element key; it is set by WriteDocumentElement and
// consumed by writeElementHeader.
key string
// arrkey is an array index: on an array frame it is the index of the next
// element, on a value frame it is the index written as the element key.
arrkey int
// start is the offset in the writer's buffer of the four bytes reserved
// for this frame's length (see reserveLength and writeLength).
start int32
}
// valueWriter buffers BSON output in buf and tracks the nesting of the value
// being written with a stack of vwState frames.
type valueWriter struct {
// w is the optional destination that Flush writes buf to; Flush is a
// no-op when it is nil.
w io.Writer
// buf accumulates the encoded bytes.
buf []byte

// stack holds one frame per nesting level; frame indexes the current one.
stack []vwState
frame int64
}
// advanceFrame moves to the next stack frame, growing the stack by one zero
// frame when the new index is past the end.
func (vw *valueWriter) advanceFrame() {
	vw.frame++
	if int64(len(vw.stack)) <= vw.frame {
		vw.stack = append(vw.stack, vwState{})
	}
}
// push enters a new frame with mode m. Document, array and code-with-scope
// frames also reserve their four-byte length prefix.
func (vw *valueWriter) push(m mode) {
	vw.advanceFrame()

	// Overwrite whatever an earlier use of this slot left behind.
	vw.stack[vw.frame] = vwState{mode: m}

	if m == mDocument || m == mArray || m == mCodeWithScope {
		// Records the frame's start offset, which writeLength later reads
		// to back-fill the length.
		vw.reserveLength()
	}
}
// reserveLength remembers the current end of the buffer as the frame's start
// offset and appends four zero bytes as a placeholder for the length.
func (vw *valueWriter) reserveLength() {
	current := &vw.stack[vw.frame]
	current.start = int32(len(vw.buf))
	vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
}
// pop leaves the current frame. Element and value frames step back one
// level. Document, array and code-with-scope frames step back two, jumping
// over the element frame that introduced them
// (mDocument -> mElement -> mDocument/mTopLevel/etc...). Any other mode
// leaves the frame index unchanged.
func (vw *valueWriter) pop() {
	var levels int64
	switch vw.stack[vw.frame].mode {
	case mElement, mValue:
		levels = 1
	case mDocument, mArray, mCodeWithScope:
		levels = 2
	}
	vw.frame -= levels
}
// NewDocumentWriter creates a ValueWriter that writes BSON to w.
//
// The returned writer buffers the document while it is being built and only
// writes complete documents to w.
func NewDocumentWriter(w io.Writer) ValueWriter {
	vw := newDocumentWriter(w)
	return vw
}
// newDocumentWriter builds a valueWriter targeting w whose stack holds a
// single top-level frame (with room for five frames before it must grow).
func newDocumentWriter(w io.Writer) *valueWriter {
	frames := make([]vwState, 1, 5)
	frames[0].mode = mTopLevel

	return &valueWriter{
		w:     w,
		stack: frames,
	}
}
// newValueWriterFromSlice builds a valueWriter with no io.Writer attached that
// appends to buf. Its stack starts with a single top-level frame.
func newValueWriterFromSlice(buf []byte) *valueWriter {
	frames := make([]vwState, 1, 5)
	frames[0].mode = mTopLevel

	return &valueWriter{
		stack: frames,
		buf:   buf,
	}
}
// reset returns the writer to its initial state: a single top-level frame,
// buf as the output buffer, and no io.Writer attached. An existing stack
// allocation is reused.
func (vw *valueWriter) reset(buf []byte) {
	if vw.stack == nil {
		vw.stack = make([]vwState, 1, 5)
	} else {
		vw.stack = vw.stack[:1]
	}
	vw.stack[0] = vwState{mode: mTopLevel}

	vw.buf, vw.frame, vw.w = buf, 0, nil
}
// invalidTransitionError builds a TransitionError describing a rejected
// "write" transition from the current frame's mode to destination. name is
// the calling method and modes lists the modes that would have been
// accepted. The parent mode is filled in only when the current frame is not
// the bottom of the stack.
func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
	transErr := TransitionError{
		action:      "write",
		name:        name,
		modes:       modes,
		destination: destination,
		current:     vw.stack[vw.frame].mode,
	}

	if vw.frame != 0 {
		transErr.parent = vw.stack[vw.frame-1].mode
	}

	return transErr
}
// writeElementHeader appends the element header (type byte plus key) for a
// value of type t.
//
// In element mode the frame's key is used and must not contain a null byte.
// In value mode the frame's array index is used as the key. In any other
// mode an invalid-transition error is returned; destination, callerName and
// addmodes only feed that error, with addmodes extending the list of modes
// reported as acceptable.
func (vw *valueWriter) writeElementHeader(t Type, destination mode, callerName string, addmodes ...mode) error {
	current := &vw.stack[vw.frame]

	if current.mode == mElement {
		if !isValidCString(current.key) {
			return errors.New("BSON element key cannot contain null bytes")
		}
		vw.appendHeader(t, current.key)
		return nil
	}

	if current.mode == mValue {
		vw.appendIntHeader(t, current.arrkey)
		return nil
	}

	allowed := append([]mode{mElement, mValue}, addmodes...)
	return vw.invalidTransitionError(destination, callerName, allowed)
}
// writeValueBytes writes the header for a value of type t and then appends b
// verbatim as the value's encoded bytes.
func (vw *valueWriter) writeValueBytes(t Type, b []byte) error {
	err := vw.writeElementHeader(t, mode(0), "WriteValueBytes")
	if err != nil {
		return err
	}

	vw.buf = append(vw.buf, b...)
	vw.pop()
	return nil
}
// WriteArray writes an array element header and enters array mode. The
// returned ArrayWriter is the receiver itself.
func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
	err := vw.writeElementHeader(TypeArray, mArray, "WriteArray")
	if err != nil {
		return nil, err
	}

	vw.push(mArray)
	return vw, nil
}
// WriteBinary writes b as a binary value with subtype 0x00.
func (vw *valueWriter) WriteBinary(b []byte) error {
	const subtype byte = 0x00
	return vw.WriteBinaryWithSubtype(b, subtype)
}
// WriteBinaryWithSubtype writes b as a binary value tagged with subtype
// btype.
func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
	err := vw.writeElementHeader(TypeBinary, mode(0), "WriteBinaryWithSubtype")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
	vw.pop()
	return nil
}
// WriteBoolean writes b as a boolean value.
func (vw *valueWriter) WriteBoolean(b bool) error {
	err := vw.writeElementHeader(TypeBoolean, mode(0), "WriteBoolean")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendBoolean(vw.buf, b)
	vw.pop()
	return nil
}
// WriteCodeWithScope writes the header and code string of a code-with-scope
// value and returns a DocumentWriter (the receiver) for the scope document.
func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
	err := vw.writeElementHeader(TypeCodeWithScope, mCodeWithScope, "WriteCodeWithScope")
	if err != nil {
		return nil, err
	}

	// Code-with-scope needs one more stack frame than other types. When the
	// scope document ends, WriteDocumentEnd writes the document length, pops,
	// writes the code-with-scope length, and pops again. Pushing a spacer
	// frame that is always jumped over keeps pop's two-level step valid for
	// the document frame.
	vw.push(mCodeWithScope)
	vw.buf = bsoncore.AppendString(vw.buf, code)
	vw.push(mSpacer)
	vw.push(mDocument)

	return vw, nil
}
// WriteDBPointer writes a DBPointer value made of namespace ns and ObjectID
// oid.
func (vw *valueWriter) WriteDBPointer(ns string, oid ObjectID) error {
	err := vw.writeElementHeader(TypeDBPointer, mode(0), "WriteDBPointer")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
	vw.pop()
	return nil
}
// WriteDateTime writes dt as a datetime value.
func (vw *valueWriter) WriteDateTime(dt int64) error {
	err := vw.writeElementHeader(TypeDateTime, mode(0), "WriteDateTime")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
	vw.pop()
	return nil
}
// WriteDecimal128 writes d128 as a decimal128 value, passing the two values
// returned by d128.GetBytes to bsoncore.AppendDecimal128.
func (vw *valueWriter) WriteDecimal128(d128 Decimal128) error {
	err := vw.writeElementHeader(TypeDecimal128, mode(0), "WriteDecimal128")
	if err != nil {
		return err
	}

	first, second := d128.GetBytes()
	vw.buf = bsoncore.AppendDecimal128(vw.buf, first, second)
	vw.pop()
	return nil
}
// WriteDouble writes f as a double value.
func (vw *valueWriter) WriteDouble(f float64) error {
	err := vw.writeElementHeader(TypeDouble, mode(0), "WriteDouble")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendDouble(vw.buf, f)
	vw.pop()
	return nil
}
// WriteInt32 writes i32 as an int32 value.
func (vw *valueWriter) WriteInt32(i32 int32) error {
	err := vw.writeElementHeader(TypeInt32, mode(0), "WriteInt32")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendInt32(vw.buf, i32)
	vw.pop()
	return nil
}
// WriteInt64 writes i64 as an int64 value.
func (vw *valueWriter) WriteInt64(i64 int64) error {
	err := vw.writeElementHeader(TypeInt64, mode(0), "WriteInt64")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendInt64(vw.buf, i64)
	vw.pop()
	return nil
}
// WriteJavascript writes code as a JavaScript value.
func (vw *valueWriter) WriteJavascript(code string) error {
	err := vw.writeElementHeader(TypeJavaScript, mode(0), "WriteJavascript")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
	vw.pop()
	return nil
}
// WriteMaxKey writes a max-key value. Only the element header is appended;
// the value has no payload.
func (vw *valueWriter) WriteMaxKey() error {
	err := vw.writeElementHeader(TypeMaxKey, mode(0), "WriteMaxKey")
	if err != nil {
		return err
	}

	vw.pop()
	return nil
}
// WriteMinKey writes a min-key value. Only the element header is appended;
// the value has no payload.
func (vw *valueWriter) WriteMinKey() error {
	err := vw.writeElementHeader(TypeMinKey, mode(0), "WriteMinKey")
	if err != nil {
		return err
	}

	vw.pop()
	return nil
}
// WriteNull writes a null value. Only the element header is appended; the
// value has no payload.
func (vw *valueWriter) WriteNull() error {
	err := vw.writeElementHeader(TypeNull, mode(0), "WriteNull")
	if err != nil {
		return err
	}

	vw.pop()
	return nil
}
// WriteObjectID writes oid as an ObjectID value.
func (vw *valueWriter) WriteObjectID(oid ObjectID) error {
	err := vw.writeElementHeader(TypeObjectID, mode(0), "WriteObjectID")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
	vw.pop()
	return nil
}
// WriteRegex writes a regular expression value. Both pattern and options
// must be free of null bytes; this is checked before anything is written.
// The options are passed through sortStringAlphebeticAscending before being
// encoded.
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
	if !(isValidCString(pattern) && isValidCString(options)) {
		return errors.New("BSON regex values cannot contain null bytes")
	}

	err := vw.writeElementHeader(TypeRegex, mode(0), "WriteRegex")
	if err != nil {
		return err
	}

	sortedOptions := sortStringAlphebeticAscending(options)
	vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortedOptions)
	vw.pop()
	return nil
}
// WriteString writes s as a string value.
func (vw *valueWriter) WriteString(s string) error {
	err := vw.writeElementHeader(TypeString, mode(0), "WriteString")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendString(vw.buf, s)
	vw.pop()
	return nil
}
// WriteDocument starts a document and returns the receiver as its
// DocumentWriter.
//
// At the top level only the length prefix is reserved; no header is written
// and no frame is pushed. Otherwise an embedded-document header is written
// and a document frame is pushed.
func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
	if vw.stack[vw.frame].mode != mTopLevel {
		err := vw.writeElementHeader(TypeEmbeddedDocument, mDocument, "WriteDocument", mTopLevel)
		if err != nil {
			return nil, err
		}
		vw.push(mDocument)
		return vw, nil
	}

	vw.reserveLength()
	return vw, nil
}
// WriteSymbol writes symbol as a symbol value.
func (vw *valueWriter) WriteSymbol(symbol string) error {
	err := vw.writeElementHeader(TypeSymbol, mode(0), "WriteSymbol")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
	vw.pop()
	return nil
}
// WriteTimestamp writes a timestamp value from its two uint32 components t
// and i.
func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
	err := vw.writeElementHeader(TypeTimestamp, mode(0), "WriteTimestamp")
	if err != nil {
		return err
	}

	vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
	vw.pop()
	return nil
}
// WriteUndefined writes an undefined value. Only the element header is
// appended; the value has no payload.
func (vw *valueWriter) WriteUndefined() error {
	err := vw.writeElementHeader(TypeUndefined, mode(0), "WriteUndefined")
	if err != nil {
		return err
	}

	vw.pop()
	return nil
}
// WriteDocumentElement begins a new element named key inside the current
// document and returns the receiver as the ValueWriter for its value. It is
// only valid in top-level or document mode. The key is stored on the new
// frame and validated later, when the value's header is written.
func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
	if current := vw.stack[vw.frame].mode; current != mTopLevel && current != mDocument {
		return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
	}

	vw.push(mElement)
	vw.stack[vw.frame].key = key

	return vw, nil
}
// WriteDocumentEnd terminates the current document: it appends the trailing
// null byte, back-fills the document length, and pops the frame. A top-level
// document is additionally flushed. If the enclosing frame turns out to be a
// code-with-scope value, its length is back-filled and it is popped as well.
func (vw *valueWriter) WriteDocumentEnd() error {
	current := vw.stack[vw.frame].mode
	if current != mTopLevel && current != mDocument {
		return fmt.Errorf("incorrect mode to end document: %s", current)
	}

	vw.buf = append(vw.buf, 0x00)

	if err := vw.writeLength(); err != nil {
		return err
	}

	if current == mTopLevel {
		if err := vw.Flush(); err != nil {
			return err
		}
	}

	vw.pop()

	if vw.stack[vw.frame].mode == mCodeWithScope {
		// The error is ignored on purpose: writeLength just succeeded on the
		// same, unmodified buffer, so per its documented guarantee it cannot
		// fail here.
		_ = vw.writeLength()
		vw.pop()
	}

	return nil
}
// Flush writes the buffered bytes to the attached io.Writer and, on success,
// truncates the buffer to zero length while keeping its capacity. It does
// nothing when no writer is attached. On a write error the buffer is left
// as it was.
func (vw *valueWriter) Flush() error {
	if vw.w == nil {
		return nil
	}

	_, err := vw.w.Write(vw.buf)
	if err != nil {
		return err
	}

	// Reuse the backing array for the next document.
	vw.buf = vw.buf[:0]
	return nil
}
// WriteArrayElement begins the next element of the current array and returns
// the receiver as the ValueWriter for its value. The array frame's index
// counter is advanced and the index of this element is stored on the new
// value frame. It is only valid in array mode.
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
	array := &vw.stack[vw.frame]
	if array.mode != mArray {
		return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
	}

	index := array.arrkey
	array.arrkey++

	// push may grow the stack, so the new frame is addressed afresh below.
	vw.push(mValue)
	vw.stack[vw.frame].arrkey = index

	return vw, nil
}
// WriteArrayEnd terminates the current array: it appends the trailing null
// byte, back-fills the array length, and pops the frame. It is only valid in
// array mode.
func (vw *valueWriter) WriteArrayEnd() error {
	if current := vw.stack[vw.frame].mode; current != mArray {
		return fmt.Errorf("incorrect mode to end array: %s", current)
	}

	vw.buf = append(vw.buf, 0x00)

	if err := vw.writeLength(); err != nil {
		return err
	}

	vw.pop()
	return nil
}
// writeLength back-fills the four bytes reserved at the current frame's
// start offset with the little-endian length measured from that offset to
// the end of the buffer. It returns errMaxDocumentSizeExceeded when the
// whole buffer is longer than maxSize.
//
// NOTE: Callers assume that calling writeLength more than once within the
// same function, without altering vw.buf in between, cannot return an error
// after the first call succeeded. If this changes, ensure that the following
// methods are updated:
//
//   - WriteDocumentEnd
func (vw *valueWriter) writeLength() error {
	total := len(vw.buf)
	if total > maxSize {
		return errMaxDocumentSizeExceeded{size: int64(total)}
	}

	start := vw.stack[vw.frame].start
	length := total - int(start)

	_ = vw.buf[start+3] // BCE
	vw.buf[start+0] = byte(length)
	vw.buf[start+1] = byte(length >> 8)
	vw.buf[start+2] = byte(length >> 16)
	vw.buf[start+3] = byte(length >> 24)

	return nil
}
// isValidCString reports whether cs can be encoded as a C string, i.e.
// whether it contains no zero byte (the zero byte is the terminator).
func isValidCString(cs string) bool {
	// Scanning bytes rather than runes is safe: every byte of a multibyte
	// UTF-8 sequence has its high bit set, so a 0x00 byte can never be part
	// of one. This mirrors the "r < utf8.RuneSelf" fast path in
	// strings.IndexRune while staying small enough to be inlined.
	//
	// https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127
	return strings.IndexByte(cs, 0x00) < 0
}
// appendHeader appends an element header to the buffer: the type byte, the
// key bytes, and a terminating null byte. It is the same as
// bsoncore.AppendHeader minus the C-string validation of key.
//
// The caller of this function must check if key is a valid C string.
func (vw *valueWriter) appendHeader(t Type, key string) {
	out := bsoncore.AppendType(vw.buf, bsoncore.Type(t))
	out = append(out, key...)
	vw.buf = append(out, 0x00)
}
// appendIntHeader appends an element header whose key is the base-10 text of
// key: the type byte, the decimal digits, and a terminating null byte.
func (vw *valueWriter) appendIntHeader(t Type, key int) {
	out := bsoncore.AppendType(vw.buf, bsoncore.Type(t))
	out = strconv.AppendInt(out, int64(key), 10)
	vw.buf = append(out, 0x00)
}
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/binary"
"errors"
"fmt"
"math"
)
// BSON binary vector types as described in https://bsonspec.org/spec.html.
//
// The value is stored as the first byte of a vector's binary payload and
// selects how the remaining bytes are interpreted.
const (
// Int8Vector marks a vector whose elements are int8 values.
Int8Vector byte = 0x03
// Float32Vector marks a vector whose elements are float32 values.
Float32Vector byte = 0x27
// PackedBitVector marks a vector of packed bits.
PackedBitVector byte = 0x10
)
// These are vector conversion errors.
var (
// errInsufficientVectorData is returned when a binary payload is too short
// or, for float32 vectors, not a whole number of 4-byte elements.
errInsufficientVectorData = errors.New("insufficient data")
// errNonZeroVectorPadding is returned when a padding value is non-zero
// where it is required to be zero.
errNonZeroVectorPadding = errors.New("padding must be 0")
// errVectorPaddingTooLarge is returned when a packed-bit padding exceeds 7.
errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7")
)
// vectorTypeError is the panic value used when a typed accessor is called on
// a Vector of a different data type. Method names the accessor and Type is
// the vector's actual data type.
type vectorTypeError struct {
	Method string
	Type   byte
}

// Error implements the error interface.
func (vte vectorTypeError) Error() string {
	var kind string
	switch vte.Type {
	case Int8Vector:
		kind = "int8"
	case Float32Vector:
		kind = "float32"
	case PackedBitVector:
		kind = "packed bit"
	default:
		kind = "invalid"
	}
	return fmt.Sprintf("cannot call %s, on a type %s vector", vte.Method, kind)
}
// Vector represents a densely packed array of numbers / bits.
//
// dType selects which of the data fields is meaningful; the constructors in
// this file populate only the field that matches it.
type Vector struct {
// dType is one of Int8Vector, Float32Vector or PackedBitVector.
dType byte
// int8Data holds the elements of an int8 vector.
int8Data []int8
// float32Data holds the elements of a float32 vector.
float32Data []float32
// bitData holds the bytes of a packed-bit vector.
bitData []byte
// bitPadding is the number of bits in the final byte of bitData that are
// to be ignored.
bitPadding uint8
}
// Type returns the vector's data type: Int8Vector, Float32Vector or
// PackedBitVector for vectors built by this package's constructors.
func (v Vector) Type() byte {
	dType := v.dType
	return dType
}
// Int8 returns the int8 slice held by the vector.
// It panics if v is not an int8 vector.
func (v Vector) Int8() []int8 {
	if data, ok := v.Int8OK(); ok {
		return data
	}
	panic(vectorTypeError{"bson.Vector.Int8", v.dType})
}
// Int8OK is the same as Int8, but reports a type mismatch through its boolean
// result instead of panicking.
func (v Vector) Int8OK() ([]int8, bool) {
	if v.dType == Int8Vector {
		return v.int8Data, true
	}
	return nil, false
}
// Float32 returns the float32 slice held by the vector.
// It panics if v is not a float32 vector.
func (v Vector) Float32() []float32 {
	if data, ok := v.Float32OK(); ok {
		return data
	}
	panic(vectorTypeError{"bson.Vector.Float32", v.dType})
}
// Float32OK is the same as Float32, but reports a type mismatch through its
// boolean result instead of panicking.
func (v Vector) Float32OK() ([]float32, bool) {
	if v.dType == Float32Vector {
		return v.float32Data, true
	}
	return nil, false
}
// PackedBit returns the byte slice representing the binary quantized (packed
// bit) vector together with the byte padding, which is the number of bits in
// the final byte that are to be ignored.
// It panics if v is not a packed bit vector.
func (v Vector) PackedBit() ([]byte, uint8) {
	if bits, padding, ok := v.PackedBitOK(); ok {
		return bits, padding
	}
	panic(vectorTypeError{"bson.Vector.PackedBit", v.dType})
}
// PackedBitOK is the same as PackedBit, but reports a type mismatch through
// its boolean result instead of panicking.
func (v Vector) PackedBitOK() ([]byte, uint8, bool) {
	if v.dType == PackedBitVector {
		return v.bitData, v.bitPadding, true
	}
	return nil, 0, false
}
// Binary returns the BSON Binary representation of the Vector.
// It panics if the vector's data type is not one of the known vector types,
// which includes the zero Vector.
func (v Vector) Binary() Binary {
	switch v.Type() {
	case Int8Vector:
		values := v.Int8()
		return binaryFromInt8Vector(values)
	case Float32Vector:
		values := v.Float32()
		return binaryFromFloat32Vector(values)
	case PackedBitVector:
		bits, padding := v.PackedBit()
		return binaryFromBitVector(bits, padding)
	default:
		panic(fmt.Sprintf("invalid Vector data type: %d", v.dType))
	}
}
// binaryFromInt8Vector encodes v as a vector Binary: a two-byte header (the
// Int8Vector type byte and a zero padding byte) followed by one byte per
// element.
func binaryFromInt8Vector(v []int8) Binary {
	payload := make([]byte, 2, len(v)+2)
	payload[0], payload[1] = Int8Vector, 0

	for _, elem := range v {
		payload = append(payload, byte(elem))
	}

	return Binary{Subtype: TypeBinaryVector, Data: payload}
}
// binaryFromFloat32Vector encodes v as a vector Binary: a two-byte header
// (the Float32Vector type byte and a zero padding byte) followed by each
// element's IEEE 754 bits as four little-endian bytes.
func binaryFromFloat32Vector(v []float32) Binary {
	const headerLen, elemLen = 2, 4

	payload := make([]byte, len(v)*elemLen+headerLen)
	payload[0], payload[1] = Float32Vector, 0

	for i, elem := range v {
		offset := headerLen + i*elemLen
		binary.LittleEndian.PutUint32(payload[offset:offset+elemLen], math.Float32bits(elem))
	}

	return Binary{Subtype: TypeBinaryVector, Data: payload}
}
// binaryFromBitVector encodes a packed-bit vector as a vector Binary: a
// two-byte header (the PackedBitVector type byte and the padding) followed
// by a copy of bits.
func binaryFromBitVector(bits []byte, padding uint8) Binary {
	payload := make([]byte, 2, len(bits)+2)
	payload[0], payload[1] = PackedBitVector, padding
	payload = append(payload, bits...)

	return Binary{Subtype: TypeBinaryVector, Data: payload}
}
// NewVector constructs a Vector from a slice of int8 or float32. The data is
// copied, so later changes to the argument do not affect the Vector.
func NewVector[T int8 | float32](data []T) Vector {
	switch src := any(data).(type) {
	case []int8:
		dup := make([]int8, len(data))
		copy(dup, src)
		return Vector{dType: Int8Vector, int8Data: dup}
	case []float32:
		dup := make([]float32, len(data))
		copy(dup, src)
		return Vector{dType: Float32Vector, float32Data: dup}
	default:
		// Not reachable through the type constraint; kept as a guard.
		panic(fmt.Errorf("unsupported type %T", data))
	}
}
// NewPackedBitVector constructs a Vector from a byte slice and a value of byte
// padding. The bits are copied.
//
// It returns errVectorPaddingTooLarge when padding exceeds 7 and
// errNonZeroVectorPadding when padding is non-zero but bits is empty; in both
// cases the returned Vector is the zero value.
func NewPackedBitVector(bits []byte, padding uint8) (Vector, error) {
	switch {
	case padding > 7:
		return Vector{}, errVectorPaddingTooLarge
	case padding > 0 && len(bits) == 0:
		return Vector{}, errNonZeroVectorPadding
	}

	dup := make([]byte, len(bits))
	copy(dup, bits)

	return Vector{
		dType:      PackedBitVector,
		bitData:    dup,
		bitPadding: padding,
	}, nil
}
// NewVectorFromBinary unpacks a BSON Binary into a Vector.
//
// The Binary must have the vector subtype and at least two bytes of data: a
// data-type byte followed by a padding byte. The data-type byte selects the
// decoder; an unrecognised value yields an error.
func NewVectorFromBinary(b Binary) (Vector, error) {
	if b.Subtype != TypeBinaryVector {
		return Vector{}, errors.New("not a vector")
	}
	if len(b.Data) < 2 {
		return Vector{}, errInsufficientVectorData
	}

	// rest begins with the padding byte; each decoder validates it.
	dType, rest := b.Data[0], b.Data[1:]
	switch dType {
	case Int8Vector:
		return newInt8Vector(rest)
	case Float32Vector:
		return newFloat32Vector(rest)
	case PackedBitVector:
		return newBitVector(rest)
	}
	return Vector{}, fmt.Errorf("invalid Vector data type: %d", dType)
}
// newInt8Vector decodes an int8 vector from b, whose first byte is the
// padding (which must be zero) and whose remaining bytes are the elements.
func newInt8Vector(b []byte) (Vector, error) {
	if len(b) == 0 {
		return Vector{}, errInsufficientVectorData
	}
	if b[0] != 0 {
		return Vector{}, errNonZeroVectorPadding
	}

	elems := b[1:]
	values := make([]int8, len(elems))
	for i, raw := range elems {
		values[i] = int8(raw)
	}

	return NewVector(values), nil
}
// newFloat32Vector decodes a float32 vector from b, whose first byte is the
// padding (which must be zero) and whose remaining bytes are little-endian
// 4-byte elements. A payload that is not a whole number of elements is
// rejected with errInsufficientVectorData.
func newFloat32Vector(b []byte) (Vector, error) {
	if len(b) == 0 {
		return Vector{}, errInsufficientVectorData
	}
	if b[0] != 0 {
		return Vector{}, errNonZeroVectorPadding
	}

	const elemLen = 4
	payload := b[1:]
	if len(payload)%elemLen != 0 {
		return Vector{}, errInsufficientVectorData
	}

	values := make([]float32, len(payload)/elemLen)
	for i := range values {
		bits := binary.LittleEndian.Uint32(payload[i*elemLen : (i+1)*elemLen])
		values[i] = math.Float32frombits(bits)
	}

	return NewVector(values), nil
}
// newBitVector decodes a packed-bit vector from b, whose first byte is the
// padding and whose remaining bytes are the packed bits. Validation of the
// padding is delegated to NewPackedBitVector.
func newBitVector(b []byte) (Vector, error) {
	if len(b) == 0 {
		return Vector{}, errInsufficientVectorData
	}

	padding, bits := b[0], b[1:]
	return NewPackedBitVector(bits, padding)
}
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
// ArrayWriter is the interface used to create a BSON or BSON adjacent array.
// Callers must ensure they call WriteArrayEnd when they have finished creating
// the array.
type ArrayWriter interface {
// WriteArrayElement starts the next array element and returns the
// ValueWriter used to write its value.
WriteArrayElement() (ValueWriter, error)
// WriteArrayEnd finishes the array.
WriteArrayEnd() error
}
// DocumentWriter is the interface used to create a BSON or BSON adjacent
// document. Callers must ensure they call WriteDocumentEnd when they have
// finished creating the document.
type DocumentWriter interface {
// WriteDocumentElement starts an element with the given key and returns
// the ValueWriter used to write its value.
WriteDocumentElement(string) (ValueWriter, error)
// WriteDocumentEnd finishes the document.
WriteDocumentEnd() error
}
// ValueWriter is the interface used to write BSON values. Implementations of
// this interface handle creating BSON or BSON adjacent representations of the
// values.
//
// Each method writes a single value of the named BSON type. WriteArray,
// WriteDocument and WriteCodeWithScope instead return a writer through which
// the nested elements are written; the caller must finish that writer with
// its WriteArrayEnd / WriteDocumentEnd method.
type ValueWriter interface {
WriteArray() (ArrayWriter, error)
WriteBinary(b []byte) error
WriteBinaryWithSubtype(b []byte, btype byte) error
WriteBoolean(bool) error
WriteCodeWithScope(code string) (DocumentWriter, error)
WriteDBPointer(ns string, oid ObjectID) error
WriteDateTime(dt int64) error
WriteDecimal128(Decimal128) error
WriteDouble(float64) error
WriteInt32(int32) error
WriteInt64(int64) error
WriteJavascript(code string) error
WriteMaxKey() error
WriteMinKey() error
WriteNull() error
WriteObjectID(ObjectID) error
WriteRegex(pattern, options string) error
WriteString(string) error
WriteDocument() (DocumentWriter, error)
WriteSymbol(symbol string) error
WriteTimestamp(t, i uint32) error
WriteUndefined() error
}
// sliceWriter adapts a byte slice to io.Writer: through a *sliceWriter, every
// Write appends to the slice.
type sliceWriter []byte

// Write appends p to the underlying slice. It always reports len(p) bytes
// written and never returns an error.
func (sw *sliceWriter) Write(p []byte) (int, error) {
	grown := append(*sw, p...)
	*sw = grown
	return len(p), nil
}