Skip to content

Commit

Permalink
BeforeHandler return context
Browse files Browse the repository at this point in the history
  • Loading branch information
codingcn authored and felipejfc committed Dec 26, 2020
1 parent a215e15 commit dc45a18
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 33 deletions.
8 changes: 4 additions & 4 deletions defaultpipelines/default_struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ type DefaultValidator struct {
// based on the struct tags the parameter has.
// This function has the pipeline.Handler signature so
// it is possible to use it as a pipeline function
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (interface{}, error) {
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
if in == nil {
return in, nil
return ctx, in, nil
}

v.lazyinit()
if err := v.validate.Struct(in); err != nil {
return nil, err
return ctx, nil, err
}

return in, nil
return ctx, in, nil
}

func (v *DefaultValidator) lazyinit() {
Expand Down
4 changes: 2 additions & 2 deletions defaultpipelines/default_struct_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestDefaultValidator(t *testing.T) {
t.Run(tname, func(t *testing.T) {
var err error
if tbl.s == nil {
_, err = validator.Validate(context.Background(), nil)
_, _, err = validator.Validate(context.Background(), nil)
} else {
_, err = validator.Validate(context.Background(), tbl.s)
_, _, err = validator.Validate(context.Background(), tbl.s)
}

if tbl.shouldFail {
Expand Down
2 changes: 1 addition & 1 deletion defaultpipelines/struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
//
// The default struct validator used by pitaya is https://github.com/go-playground/validator.
type StructValidator interface {
Validate(context.Context, interface{}) (interface{}, error)
Validate(context.Context, interface{}) (context.Context, interface{}, error)
}

// StructValidatorInstance holds the default validator
Expand Down
2 changes: 1 addition & 1 deletion pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (
type (
// HandlerTempl is a function that has the same signature as a handler and will
// be called before or after handler methods
HandlerTempl func(ctx context.Context, in interface{}) (out interface{}, err error)
HandlerTempl func(ctx context.Context, in interface{}) (c context.Context, out interface{}, err error)

// AfterHandlerTempl is a function for the after handler, receives both the handler response
// and the error returned
Expand Down
12 changes: 6 additions & 6 deletions pipeline/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (
)

var (
handler1 = func(ctx context.Context, in interface{}) (interface{}, error) {
return in, errors.New("ohno")
handler1 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, in, errors.New("ohno")
}
handler2 = func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, nil
handler2 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, nil
}
p = &pipelineChannel{}
)
Expand All @@ -43,7 +43,7 @@ func TestPushFront(t *testing.T) {
p.PushFront(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.Nil(t, nil, err)
}

Expand All @@ -52,7 +52,7 @@ func TestPushBack(t *testing.T) {
p.PushBack(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.EqualError(t, errors.New("ohno"), err.Error())
}

Expand Down
8 changes: 5 additions & 3 deletions pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ func resetPipelines() {
pipeline.AfterHandler.Handlers = make([]pipeline.AfterHandlerTempl, 0)
}

var myHandler = func(ctx context.Context, in interface{}) (interface{}, error) {
return []byte("test"), nil
var myHandler = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
ctx = context.WithValue(ctx, "traceID", "123456")
return ctx, []byte("test"), nil
}

var myAfterHandler = func(ctx context.Context, out interface{}, err error) (interface{}, error) {
Expand All @@ -44,9 +45,10 @@ var myAfterHandler = func(ctx context.Context, out interface{}, err error) (inte
func TestBeforeHandler(t *testing.T) {
resetPipelines()
BeforeHandler(myHandler)
r, err := pipeline.BeforeHandler.Handlers[0](nil, nil)
ctx, r, err := pipeline.BeforeHandler.Handlers[0](nil, nil)
assert.NoError(t, err)
assert.Equal(t, []byte("test"), r)
assert.Equal(t, "123456", ctx.Value("traceID").(string))
}

func TestAfterHandler(t *testing.T) {
Expand Down
11 changes: 6 additions & 5 deletions service/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,19 @@ func getMsgType(msgTypeIface interface{}) (message.Type, error) {
return msgType, nil
}

func executeBeforePipeline(ctx context.Context, data interface{}) (interface{}, error) {
func executeBeforePipeline(ctx context.Context, data interface{}) (context.Context, interface{}, error) {
var err error
res := data
if len(pipeline.BeforeHandler.Handlers) > 0 {
for _, h := range pipeline.BeforeHandler.Handlers {
res, err = h(ctx, res)
ctx, _, err = h(ctx, res)
if err != nil {
logger.Log.Debugf("pitaya/handler: broken pipeline: %s", err.Error())
return res, err
return ctx, res, err
}
}
}
return res, nil
return ctx, res, nil
}

func executeAfterPipeline(ctx context.Context, res interface{}, err error) (interface{}, error) {
Expand Down Expand Up @@ -174,7 +174,8 @@ func processHandlerMessage(
return nil, e.NewError(err, e.ErrBadRequestCode)
}

if arg, err = executeBeforePipeline(ctx, arg); err != nil {
ctx, arg, err = executeBeforePipeline(ctx, arg)
if err != nil {
return nil, err
}

Expand Down
22 changes: 11 additions & 11 deletions service/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func TestGetMsgType(t *testing.T) {

func TestExecuteBeforePipelineEmpty(t *testing.T) {
expected := []byte("ok")
res, err := executeBeforePipeline(nil, expected)
_, res, err := executeBeforePipeline(nil, expected)
assert.NoError(t, err)
assert.Equal(t, expected, res)
}
Expand All @@ -217,36 +217,36 @@ func TestExecuteBeforePipelineSuccess(t *testing.T) {
data := []byte("ok")
expected1 := []byte("oh noes 1")
expected2 := []byte("oh noes 2")
before1 := func(ctx context.Context, in interface{}) (interface{}, error) {
before1 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, data, in)
return expected1, nil
return ctx, expected1, nil
}
before2 := func(ctx context.Context, in interface{}) (interface{}, error) {
before2 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, expected1, in)
return expected2, nil
return ctx, expected2, nil
}
pipeline.BeforeHandler.PushBack(before1)
pipeline.BeforeHandler.PushBack(before2)
defer pipeline.BeforeHandler.Clear()

res, err := executeBeforePipeline(c, data)
_, res, err := executeBeforePipeline(c, data)
assert.NoError(t, err)
assert.Equal(t, expected2, res)
}

func TestExecuteBeforePipelineError(t *testing.T) {
c := context.Background()
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
return nil, expected
return ctx, nil, expected
}
pipeline.BeforeHandler.PushFront(before)
defer pipeline.BeforeHandler.Clear()

_, err := executeBeforePipeline(c, []byte("ok"))
_, _, err := executeBeforePipeline(c, []byte("ok"))
assert.Equal(t, expected, err)
}

Expand Down Expand Up @@ -408,8 +408,8 @@ func TestProcessHandlerMessageBrokenBeforePipeline(t *testing.T) {
handlers[rt.Short()] = &component.Handler{}
defer func() { delete(handlers, rt.Short()) }()
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, expected
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, expected
}
pipeline.BeforeHandler.PushFront(before)
defer pipeline.BeforeHandler.Clear()
Expand Down

0 comments on commit dc45a18

Please sign in to comment.