Skip to content

Commit

Permalink
fix: lorry security problems (apecloud#5886)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuriwuyun authored Nov 22, 2023
1 parent 42fb966 commit 713a25c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 166 deletions.
7 changes: 4 additions & 3 deletions pkg/lorry/client/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,15 @@ var _ = AfterSuite(func() {
})

func newTCPServer(port int) (net.Listener, int) {
var err error
for i := 0; i < 3; i++ {
tcpListener, _ = net.Listen("tcp", fmt.Sprintf(":%v", port))
if tcpListener != nil {
tcpListener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
if err == nil {
break
}
port++
}
Expect(tcpListener).ShouldNot(BeNil())
Expect(err).Should(BeNil())
return tcpListener, port
}

Expand Down
163 changes: 3 additions & 160 deletions pkg/lorry/engines/kafka/thirdparty/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@ package thirdparty
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"sync/atomic"
"time"
"unicode"

"github.com/mitchellh/mapstructure"

"github.com/cenkalti/backoff/v4"
"github.com/pkg/errors"

"github.com/apecloud/kubeblocks/pkg/lorry/util/config"
)

// PolicyType denotes if the back off delay should be constant or exponential.
Expand Down Expand Up @@ -141,162 +139,7 @@ func DecodeConfig(c *Config, input interface{}) error {
*c = DefaultConfig()
}

return Decode(input, c)
}
func Decode(input interface{}, output interface{}) error {
decoder, err := mapstructure.NewDecoder(
&mapstructure.DecoderConfig{ //nolint: exhaustruct
Result: output,
DecodeHook: decodeString,
})
if err != nil {
return err
}

return decoder.Decode(input)
}

var (
typeDuration = reflect.TypeOf(time.Duration(5)) //nolint: gochecknoglobals
typeTime = reflect.TypeOf(time.Time{}) //nolint: gochecknoglobals
typeStringDecoder = reflect.TypeOf((*StringDecoder)(nil)).Elem() //nolint: gochecknoglobals
)

type StringDecoder interface {
DecodeString(value string) error
}

//nolint:cyclop
func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {
if t.Kind() == reflect.String && f.Kind() != reflect.String {
return fmt.Sprintf("%v", data), nil
}
if f.Kind() == reflect.Ptr {
f = f.Elem()
data = reflect.ValueOf(data).Elem().Interface()
}
if f.Kind() != reflect.String {
return data, nil
}

dataString, ok := data.(string)
if !ok {
return nil, errors.Errorf("expected string: got %s", reflect.TypeOf(data))
}

var result any
var decoder StringDecoder

if t.Implements(typeStringDecoder) {
result = reflect.New(t.Elem()).Interface()
decoder = result.(StringDecoder)
} else if reflect.PtrTo(t).Implements(typeStringDecoder) {
result = reflect.New(t).Interface()
decoder = result.(StringDecoder)
}

if decoder != nil {
if err := decoder.DecodeString(dataString); err != nil {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}

return nil, errors.Errorf("invalid %s %q: %v", t.Name(), dataString, err)
}

return result, nil
}

switch t {
case typeDuration:
// Check for simple integer values and treat them
// as milliseconds
if val, err := strconv.Atoi(dataString); err == nil {
return time.Duration(val) * time.Millisecond, nil
}

// Convert it by parsing
d, err := time.ParseDuration(dataString)

return d, invalidError(err, "duration", dataString)
case typeTime:
// Convert it by parsing
t, err := time.Parse(time.RFC3339Nano, dataString)
if err == nil {
return t, nil
}
t, err = time.Parse(time.RFC3339, dataString)

return t, invalidError(err, "time", dataString)
}

switch t.Kind() {
case reflect.Uint:
val, err := strconv.ParseUint(dataString, 10, 32)

return uint(val), invalidError(err, "uint", dataString)
case reflect.Uint64:
val, err := strconv.ParseUint(dataString, 10, 64)

return val, invalidError(err, "uint64", dataString)
case reflect.Uint32:
val, err := strconv.ParseUint(dataString, 10, 32)

return uint32(val), invalidError(err, "uint32", dataString)
case reflect.Uint16:
val, err := strconv.ParseUint(dataString, 10, 16)

return uint16(val), invalidError(err, "uint16", dataString)
case reflect.Uint8:
val, err := strconv.ParseUint(dataString, 10, 8)

return uint8(val), invalidError(err, "uint8", dataString)

case reflect.Int:
val, err := strconv.Atoi(dataString)

return val, invalidError(err, "int", dataString)
case reflect.Int64:
val, err := strconv.ParseInt(dataString, 10, 64)

return val, invalidError(err, "int64", dataString)
case reflect.Int32:
val, err := strconv.ParseInt(dataString, 10, 32)

return int32(val), invalidError(err, "int32", dataString)
case reflect.Int16:
val, err := strconv.ParseInt(dataString, 10, 16)

return int16(val), invalidError(err, "int16", dataString)
case reflect.Int8:
val, err := strconv.ParseInt(dataString, 10, 8)

return int8(val), invalidError(err, "int8", dataString)

case reflect.Float32:
val, err := strconv.ParseFloat(dataString, 32)

return float32(val), invalidError(err, "float32", dataString)
case reflect.Float64:
val, err := strconv.ParseFloat(dataString, 64)

return val, invalidError(err, "float64", dataString)

case reflect.Bool:
val, err := strconv.ParseBool(dataString)

return val, invalidError(err, "bool", dataString)

default:
return data, nil
}
}
func invalidError(err error, msg, value string) error {
if err == nil {
return nil
}

return errors.Errorf("invalid %s %q", msg, value)
return config.Decode(input, c)
}

// NotifyRecover is a wrapper around backoff.RetryNotify that adds another callback for when an operation
Expand Down
6 changes: 3 additions & 3 deletions pkg/lorry/util/config/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {

switch t.Kind() {
case reflect.Uint:
val, err := strconv.ParseUint(dataString, 10, 64)
val, err := strconv.ParseUint(dataString, 10, 32)

return uint(val), invalidError(err, "uint", dataString)
case reflect.Uint64:
Expand All @@ -144,9 +144,9 @@ func decodeString(f reflect.Type, t reflect.Type, data any) (any, error) {
return uint8(val), invalidError(err, "uint8", dataString)

case reflect.Int:
val, err := strconv.ParseInt(dataString, 10, 64)
val, err := strconv.Atoi(dataString)

return int(val), invalidError(err, "int", dataString)
return val, invalidError(err, "int", dataString)
case reflect.Int64:
val, err := strconv.ParseInt(dataString, 10, 64)

Expand Down

0 comments on commit 713a25c

Please sign in to comment.