Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/fory/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (s byteArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType
func (s byteArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
length := buf.ReadLength(err)
length := buf.ReadCollectionLength(err)
if ctx.HasError() {
return
}
Expand Down
26 changes: 13 additions & 13 deletions go/fory/array_primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (s boolArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType
func (s boolArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
length := buf.ReadLength(err)
length := buf.ReadBinaryLength(err)
if ctx.HasError() {
return
}
Expand Down Expand Up @@ -131,7 +131,7 @@ func (s int8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType
func (s int8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
length := buf.ReadLength(err)
length := buf.ReadBinaryLength(err)
if ctx.HasError() {
return
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func (s int16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp
func (s int16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 2
if ctx.HasError() {
return
Expand Down Expand Up @@ -269,7 +269,7 @@ func (s int32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp
func (s int32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 4
if ctx.HasError() {
return
Expand Down Expand Up @@ -341,7 +341,7 @@ func (s int64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp
func (s int64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 8
if ctx.HasError() {
return
Expand Down Expand Up @@ -413,7 +413,7 @@ func (s float32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT
func (s float32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 4
if ctx.HasError() {
return
Expand Down Expand Up @@ -485,7 +485,7 @@ func (s float64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT
func (s float64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 8
if ctx.HasError() {
return
Expand Down Expand Up @@ -556,7 +556,7 @@ func (s uint8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp
func (s uint8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
length := buf.ReadLength(err)
length := buf.ReadBinaryLength(err)
if ctx.HasError() {
return
}
Expand Down Expand Up @@ -623,7 +623,7 @@ func (s uint16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy
func (s uint16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 2
if ctx.HasError() {
return
Expand Down Expand Up @@ -694,7 +694,7 @@ func (s uint32ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy
func (s uint32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 4
if ctx.HasError() {
return
Expand Down Expand Up @@ -764,7 +764,7 @@ func (s uint64ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTy
func (s uint64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
err := ctx.Err()
size := buf.ReadLength(err)
size := buf.ReadBinaryLength(err)
length := size / 8
if ctx.HasError() {
return
Expand Down Expand Up @@ -838,7 +838,7 @@ func (s float16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeT
func (s float16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
size := buf.ReadLength(ctxErr)
size := buf.ReadBinaryLength(ctxErr)
length := size / 2
if ctx.HasError() {
return
Expand Down Expand Up @@ -912,7 +912,7 @@ func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, write
func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
size := buf.ReadLength(ctxErr)
size := buf.ReadBinaryLength(ctxErr)
length := size / 2
if ctx.HasError() {
return
Expand Down
40 changes: 33 additions & 7 deletions go/fory/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ import (
)

type ByteBuffer struct {
data []byte // Most accessed field first for cache locality
writerIndex int
readerIndex int
reader io.Reader
bufferSize int
data []byte // Most accessed field first for cache locality
writerIndex int
readerIndex int
reader io.Reader
bufferSize int
maxCollectionSize int
maxBinarySize int
}

func NewByteBuffer(data []byte) *ByteBuffer {
Expand Down Expand Up @@ -196,8 +198,32 @@ func (b *ByteBuffer) WriteLength(value int) {
b.WriteVarUint32(uint32(value))
}

func (b *ByteBuffer) ReadLength(err *Error) int {
return int(b.ReadVarUint32(err))
func (b *ByteBuffer) ReadCollectionLength(err *Error) int {
length := int(b.ReadVarUint32(err))
if err != nil && err.HasError() {
return 0
}
if b.maxCollectionSize > 0 && length > b.maxCollectionSize {
if err != nil {
*err = MaxCollectionSizeExceededError(length, b.maxCollectionSize)
}
return 0
}
return length
}

func (b *ByteBuffer) ReadBinaryLength(err *Error) int {
length := int(b.ReadVarUint32(err))
if err != nil && err.HasError() {
return 0
}
if b.maxBinarySize > 0 && length > b.maxBinarySize {
if err != nil {
*err = MaxBinarySizeExceededError(length, b.maxBinarySize)
}
return 0
}
return length
}

func (b *ByteBuffer) WriteUint64(value uint64) {
Expand Down
20 changes: 10 additions & 10 deletions go/fory/codegen/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error {
fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n")
fmt.Fprintf(buf, "\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\tsliceLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
Expand All @@ -187,7 +187,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error {
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
fmt.Fprintf(buf, "\t\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\t\tsliceLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
Expand Down Expand Up @@ -517,7 +517,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc
fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n")
fmt.Fprintf(buf, "\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\tsliceLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String())
fmt.Fprintf(buf, "\t\t\t} else {\n")
Expand All @@ -532,7 +532,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
fmt.Fprintf(buf, "\t\t\t\tsliceLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\t\tsliceLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String())
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
Expand All @@ -555,7 +555,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi
unwrappedElem := types.Unalias(elemType)
if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() {
fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null flag\n", indent)
fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n", indent)
fmt.Fprintf(buf, "%ssliceLen := buf.ReadCollectionLength(err)\n", indent)
fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess)
fmt.Fprintf(buf, "%s} else {\n", indent)
Expand All @@ -573,7 +573,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi
}

elemIsReferencable := isReferencableType(elemType)
fmt.Fprintf(buf, "%ssliceLen := int(buf.ReadVarUint32(err))\n", indent)
fmt.Fprintf(buf, "%ssliceLen := buf.ReadCollectionLength(err)\n", indent)
fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String())
fmt.Fprintf(buf, "%s} else {\n", indent)
Expand Down Expand Up @@ -703,7 +703,7 @@ func writePrimitiveSliceReadCall(buf *bytes.Buffer, basic *types.Basic, fieldAcc
case types.Int8:
fmt.Fprintf(buf, "%s%s = fory.ReadInt8Slice(buf, err)\n", indent, fieldAccess)
case types.Uint8:
fmt.Fprintf(buf, "%ssizeBytes := buf.ReadLength(err)\n", indent)
fmt.Fprintf(buf, "%ssizeBytes := buf.ReadBinaryLength(err)\n", indent)
fmt.Fprintf(buf, "%s%s = make([]uint8, sizeBytes)\n", indent, fieldAccess)
fmt.Fprintf(buf, "%sif sizeBytes > 0 {\n", indent)
fmt.Fprintf(buf, "%s\traw := buf.ReadBinary(sizeBytes, err)\n", indent)
Expand Down Expand Up @@ -925,7 +925,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st
fmt.Fprintf(buf, "\t\tisXlang := ctx.TypeResolver().IsXlang()\n")
fmt.Fprintf(buf, "\t\tif isXlang {\n")
fmt.Fprintf(buf, "\t\t\t// xlang mode: maps are not nullable, read directly without null flag\n")
fmt.Fprintf(buf, "\t\t\tmapLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\tmapLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String())
fmt.Fprintf(buf, "\t\t\t} else {\n")
Expand All @@ -940,7 +940,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st
fmt.Fprintf(buf, "\t\t\tif nullFlag == -3 {\n") // NullFlag
fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess)
fmt.Fprintf(buf, "\t\t\t} else {\n")
fmt.Fprintf(buf, "\t\t\t\tmapLen := int(buf.ReadVarUint32(err))\n")
fmt.Fprintf(buf, "\t\t\t\tmapLen := buf.ReadCollectionLength(err)\n")
fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n")
fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String())
fmt.Fprintf(buf, "\t\t\t\t} else {\n")
Expand Down Expand Up @@ -972,7 +972,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc
}

indent := "\t\t\t"
fmt.Fprintf(buf, "%smapLen := int(buf.ReadVarUint32(err))\n", indent)
fmt.Fprintf(buf, "%smapLen := buf.ReadCollectionLength(err)\n", indent)
fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent)
fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String())
fmt.Fprintf(buf, "%s} else {\n", indent)
Expand Down
24 changes: 24 additions & 0 deletions go/fory/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ const (
ErrKindInvalidTag
// ErrKindInvalidUTF16String indicates malformed UTF-16 string data
ErrKindInvalidUTF16String
// ErrKindMaxCollectionSizeExceeded indicates max collection size exceeded
ErrKindMaxCollectionSizeExceeded
// ErrKindMaxBinarySizeExceeded indicates max binary size exceeded
ErrKindMaxBinarySizeExceeded
)

// Error is a lightweight error type optimized for hot path performance.
Expand Down Expand Up @@ -296,6 +300,26 @@ func InvalidUTF16StringError(byteCount int) Error {
})
}

// MaxCollectionSizeExceededError creates a max collection size exceeded error
//
//go:noinline
func MaxCollectionSizeExceededError(size, limit int) Error {
return panicIfEnabled(Error{
kind: ErrKindMaxCollectionSizeExceeded,
message: fmt.Sprintf("max collection size exceeded: size=%d, limit=%d", size, limit),
})
}

// MaxBinarySizeExceededError creates a max binary size exceeded error
//
//go:noinline
func MaxBinarySizeExceededError(size, limit int) Error {
return panicIfEnabled(Error{
kind: ErrKindMaxBinarySizeExceeded,
message: fmt.Sprintf("max binary size exceeded: size=%d, limit=%d", size, limit),
})
}

// WrapError wraps a standard error into a fory Error
//
//go:noinline
Expand Down
Loading
Loading