Skip to content

Commit

Permalink
adds shutdown timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
Shubhang Balkundi committed Mar 30, 2024
1 parent d7dd9eb commit 26af49d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
16 changes: 8 additions & 8 deletions logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,35 @@ import (
"github.com/rs/zerolog"
)

type textLogger struct {
type TextLogger struct {
l zerolog.Logger
}

func (h *textLogger) Info(message string, kvs ...map[string]interface{}) {
func (h *TextLogger) Info(message string, kvs ...map[string]interface{}) {
appendFields(h.l.Info(), kvs).Msg(message)
}

func (h *textLogger) Debug(message string, kvs ...map[string]interface{}) {
func (h *TextLogger) Debug(message string, kvs ...map[string]interface{}) {
appendFields(h.l.Debug(), kvs).Msg(message)
}

func (h *textLogger) Warn(message string, kvs ...map[string]interface{}) {
func (h *TextLogger) Warn(message string, kvs ...map[string]interface{}) {
appendFields(h.l.Warn(), kvs).Msg(message)
}

func (h *textLogger) Error(message string, err error, kvs ...map[string]interface{}) {
func (h *TextLogger) Error(message string, err error, kvs ...map[string]interface{}) {
if err != nil {
appendFields(h.l.Err(err), kvs).Msg(message)
}
}

func (h *textLogger) Fatal(message string, err error, kvs ...map[string]interface{}) {
func (h *TextLogger) Fatal(message string, err error, kvs ...map[string]interface{}) {
if err != nil {
appendFields(h.l.Fatal(), kvs).Msg(message)
}
}

func NewLogger(level string, opts ...func(w *zerolog.ConsoleWriter)) *textLogger {
func NewLogger(level string, opts ...func(w *zerolog.ConsoleWriter)) *TextLogger {
cw := zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
w.FormatLevel = func(i interface{}) string {
return fmt.Sprintf("[%s]", i)
Expand All @@ -48,5 +48,5 @@ func NewLogger(level string, opts ...func(w *zerolog.ConsoleWriter)) *textLogger
}
})
l := zerolog.New(cw).With().Timestamp().Logger().Level(logLevelMapping[level])
return &textLogger{l: l}
return &TextLogger{l: l}
}
26 changes: 22 additions & 4 deletions zigg.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"github.com/gojekfarm/ziggurat/v2/logger"
"sync"
"time"
)

var ErrCleanShutdown = errors.New("clean shutdown of streams")
Expand All @@ -14,9 +15,10 @@ var ErrCleanShutdown = errors.New("clean shutdown of streams")
// var z ziggurat.Ziggurat
// z.run(ctx context.Context,s ziggurat.MessageConsumer,h ziggurat.Handler)
type Ziggurat struct {
handler Handler
Logger StructuredLogger
ErrorHandler func(err error)
handler Handler
Logger StructuredLogger
ShutdownTimeout time.Duration
ErrorHandler func(err error)
}

func (z *Ziggurat) Run(ctx context.Context, handler Handler, consumers ...MessageConsumer) error {
Expand All @@ -36,9 +38,19 @@ func (z *Ziggurat) Run(ctx context.Context, handler Handler, consumers ...Messag
}(i)
}

timeout := make(chan bool, 1)
go func() {
<-ctx.Done()
<-time.After(z.ShutdownTimeout)
z.Logger.Info("ziggurat consumer orchestration wait timeout")
timeout <- true
close(errChan)
}()

go func() {
wg.Wait()
close(errChan)
close(timeout)
}()

var allErrs []error
Expand All @@ -49,6 +61,10 @@ func (z *Ziggurat) Run(ctx context.Context, handler Handler, consumers ...Messag
allErrs = append(allErrs, consErr)
}

if <-timeout {
return errors.New("shutdown timeout")
}

if len(allErrs) > 0 {
return errors.Join(allErrs...)
}
Expand All @@ -61,12 +77,14 @@ func (z *Ziggurat) mustInit(consumers []MessageConsumer, handler Handler) {
if z.Logger == nil {
z.Logger = logger.NOOP
}
if z.ShutdownTimeout == 0 {
z.ShutdownTimeout = 6000 * time.Millisecond
}
if len(consumers) < 1 {
panic("error: at least one ziggurat.MessageConsumer implementation should be provided")
}

if handler == nil {
panic("error: handler cannot be nil")
}

}
45 changes: 44 additions & 1 deletion zigg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ziggurat
import (
"context"
"errors"
"github.com/gojekfarm/ziggurat/v2/logger"
"github.com/stretchr/testify/mock"
"sync/atomic"
"testing"
Expand All @@ -11,10 +12,14 @@ import (

type MockConsumer struct {
mock.Mock
PollInterval time.Duration
}

func (m *MockConsumer) Consume(ctx context.Context, handler Handler) error {
args := m.Called(ctx, handler)
if m.PollInterval == 0 {
m.PollInterval = 200 * time.Millisecond
}

keepAlive := true

Expand All @@ -23,7 +28,7 @@ func (m *MockConsumer) Consume(ctx context.Context, handler Handler) error {
case <-ctx.Done():
keepAlive = false
default:
time.Sleep(200 * time.Millisecond)
time.Sleep(m.PollInterval)
handler.Handle(ctx, &Event{})
}
}
Expand Down Expand Up @@ -96,4 +101,42 @@ func TestZiggurat_Run(t *testing.T) {
t.Logf("message count:%d", msgCount)
})

t.Run("test shutdown timeout", func(t *testing.T) {
var zig Ziggurat
zig.ShutdownTimeout = 250 * time.Millisecond
zig.Logger = logger.NewLogger(logger.LevelInfo)

errCount := 1
zig.ErrorHandler = func(err error) {
errCount++
}

ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
defer cancel()
mc1 := MockConsumer{PollInterval: 10000 * time.Second}
mc2 := MockConsumer{}
mc3 := MockConsumer{}
handler := HandlerFunc(func(ctx context.Context, event *Event) {})

mc1.On("Consume", mock.Anything, mock.Anything).Return(nil)
mc2.On("Consume", mock.Anything, mock.Anything).Return(errors.New("mc2 errored out"))
mc3.On("Consume", mock.Anything, mock.Anything).Return(nil)

err := zig.Run(ctx, handler, &mc1, &mc2, &mc3)
if err == nil {
t.Error("expected error got nil")
return
}

if err.Error() != "shutdown timeout" {
t.Errorf("expected error:%s got %s\n", "shutdown timeout", err.Error())
return
}

if errCount < 1 {
t.Errorf("expected an error count of 1 got %d\n", errCount)
}

})

}

0 comments on commit 26af49d

Please sign in to comment.