diff --git a/common/codec/interface.go b/common/codec/interface.go index 08899399664..45245c0244b 100644 --- a/common/codec/interface.go +++ b/common/codec/interface.go @@ -21,6 +21,7 @@ package codec import ( + "go.uber.org/thriftrw/protocol/stream" "go.uber.org/thriftrw/wire" "github.com/uber/cadence/common/types" @@ -37,6 +38,8 @@ type ( ThriftObject interface { FromWire(w wire.Value) error ToWire() (wire.Value, error) + Encode(stream.Writer) error + Decode(stream.Reader) error } ) diff --git a/common/codec/version0Thriftrw.go b/common/codec/version0Thriftrw.go index 4a0704f97b4..a2cfa401008 100644 --- a/common/codec/version0Thriftrw.go +++ b/common/codec/version0Thriftrw.go @@ -23,8 +23,7 @@ package codec import ( "bytes" - "go.uber.org/thriftrw/protocol" - "go.uber.org/thriftrw/wire" + "go.uber.org/thriftrw/protocol/binary" ) type ( @@ -52,33 +51,27 @@ func (t *ThriftRWEncoder) Encode(obj ThriftObject) ([]byte, error) { if err != nil { return nil, err } - val, err := obj.ToWire() - if err != nil { - return nil, err - } - err = protocol.Binary.Encode(val, &writer) - if err != nil { + + sw := binary.Default.Writer(&writer) + defer sw.Close() + if err := obj.Encode(sw); err != nil { return nil, err } return writer.Bytes(), nil } // Decode decode the object -func (t *ThriftRWEncoder) Decode(binary []byte, val ThriftObject) error { - if len(binary) < 1 { +func (t *ThriftRWEncoder) Decode(b []byte, val ThriftObject) error { + if len(b) < 1 { return MissingBinaryEncodingVersion } - version := binary[0] + version := b[0] if version != preambleVersion0 { return InvalidBinaryEncodingVersion } - reader := bytes.NewReader(binary[1:]) - wireVal, err := protocol.Binary.Decode(reader, wire.TStruct) - if err != nil { - return err - } - - return val.FromWire(wireVal) + reader := bytes.NewReader(b[1:]) + sr := binary.Default.Reader(reader) + return val.Decode(sr) } diff --git a/common/persistence/serialization/interfaces.go b/common/persistence/serialization/interfaces.go index d1afe6c91a6..460855a52d5 100644 --- a/common/persistence/serialization/interfaces.go +++ b/common/persistence/serialization/interfaces.go @@ -25,6 +25,7 @@ package serialization import ( "time" + "go.uber.org/thriftrw/protocol/stream" "go.uber.org/thriftrw/wire" "github.com/uber/cadence/.gen/go/sqlblobs" @@ -391,5 +392,7 @@ type ( thriftRWType interface { ToWire() (wire.Value, error) FromWire(w wire.Value) error + Encode(stream.Writer) error + Decode(stream.Reader) error } ) diff --git a/common/persistence/serialization/thrift_decoder.go b/common/persistence/serialization/thrift_decoder.go index b844493fc7e..0cc831af8c9 100644 --- a/common/persistence/serialization/thrift_decoder.go +++ b/common/persistence/serialization/thrift_decoder.go @@ -25,8 +25,7 @@ package serialization import ( "bytes" - "go.uber.org/thriftrw/protocol" - "go.uber.org/thriftrw/wire" + "go.uber.org/thriftrw/protocol/binary" "github.com/uber/cadence/.gen/go/sqlblobs" ) @@ -160,9 +159,7 @@ func (d *thriftDecoder) replicationTaskInfoFromBlob(data []byte) (*ReplicationTa } func thriftRWDecode(b []byte, result thriftRWType) error { - value, err := protocol.Binary.Decode(bytes.NewReader(b), wire.TStruct) - if err != nil { - return err - } - return result.FromWire(value) + buf := bytes.NewReader(b) + sr := binary.Default.Reader(buf) + return result.Decode(sr) } diff --git a/common/persistence/serialization/thrift_encoder.go b/common/persistence/serialization/thrift_encoder.go index f32d6128b2b..a397e925d4e 100644 --- a/common/persistence/serialization/thrift_encoder.go +++ b/common/persistence/serialization/thrift_encoder.go @@ -25,7 +25,7 @@ package serialization import ( "bytes" - "go.uber.org/thriftrw/protocol" + "go.uber.org/thriftrw/protocol/binary" "github.com/uber/cadence/common" ) @@ -101,12 +101,10 @@ func (e *thriftEncoder) encodingType() common.EncodingType { } func thriftRWEncode(t thriftRWType) ([]byte, error) { - value, err := t.ToWire() - if err != nil { - return nil, err - } var b bytes.Buffer - if err := protocol.Binary.Encode(value, &b); err != nil { + sw := binary.Default.Writer(&b) + defer sw.Close() + if err := t.Encode(sw); err != nil { return nil, err } return b.Bytes(), nil