Skip to content

Commit

Permalink
Merge pull request redis#1479 from go-redis/fix/hook-call-after
Browse files Browse the repository at this point in the history
Make sure to call after hook on error
  • Loading branch information
vmihailenco authored Sep 11, 2020
2 parents fb80d42 + 69287d7 commit b67982d
Showing 1 changed file with 34 additions and 54 deletions.
88 changes: 34 additions & 54 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,83 +48,63 @@ func (hs *hooks) AddHook(hook Hook) {
func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil {
cmd.SetErr(err)
return err
if len(hs.hooks) == 0 {
return fn(ctx, cmd)
}

cmdErr := fn(ctx, cmd)
var hookIndex int
var retErr error

if err := hs.afterProcess(ctx, cmd); err != nil {
cmd.SetErr(err)
return err
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
if retErr != nil {
cmd.SetErr(retErr)
}
}

return cmdErr
}

func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, cmd)
if err != nil {
return nil, err
}
if retErr == nil {
retErr = fn(ctx, cmd)
}
return ctx, nil
}

func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error {
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil {
firstErr = err
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
retErr = err
cmd.SetErr(retErr)
}
}
return firstErr

return retErr
}

func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil {
setCmdsErr(cmds, err)
return err
if len(hs.hooks) == 0 {
return fn(ctx, cmds)
}

cmdsErr := fn(ctx, cmds)
var hookIndex int
var retErr error

if err := hs.afterProcessPipeline(ctx, cmds); err != nil {
setCmdsErr(cmds, err)
return err
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
if retErr != nil {
setCmdsErr(cmds, retErr)
}
}

return cmdsErr
}

func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
}
if retErr == nil {
retErr = fn(ctx, cmds)
}
return ctx, nil
}

func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
var firstErr error
for i := len(hs.hooks) - 1; i >= 0; i-- {
h := hs.hooks[i]
if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil {
firstErr = err
for hookIndex--; hookIndex >= 0; hookIndex-- {
if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
retErr = err
setCmdsErr(cmds, retErr)
}
}
return firstErr

return retErr
}

func (hs hooks) processTxPipeline(
Expand Down

0 comments on commit b67982d

Please sign in to comment.