Skip to content

Commit

Permalink
Added secret token validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mymmrac committed Feb 7, 2023
1 parent 0ee7efc commit 38d796d
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 21 deletions.
8 changes: 5 additions & 3 deletions examples/updates_webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ func main() {
telego.WithWebhookBuffer(128),

// Set fast http server that will be used to handle webhooks (default telego.FastHTTPWebhookServer)
// Note: If SecretToken is non-empty, it will be verified on each request
telego.WithWebhookServer(telego.FastHTTPWebhookServer{
Logger: bot.Logger(),
Server: &fasthttp.Server{},
Router: router.New(),
Logger: bot.Logger(),
Server: &fasthttp.Server{},
Router: router.New(),
SecretToken: "token",
}),

// Calls SetWebhook before starting webhook
Expand Down
67 changes: 55 additions & 12 deletions webhook_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ import (
"github.com/valyala/fasthttp"
)

// FastHTTPWebhookServer represents fasthttp implementation of WebhookServer
// Note: The user should set both Server and Router, only Logger is optional
// WebhookSecretTokenHeader represents secret token header name, see SetWebhookParams.SecretToken for more details
const WebhookSecretTokenHeader = "X-Telegram-Bot-Api-Secret-Token" //nolint:gosec

// FastHTTPWebhookServer represents fasthttp implementation of WebhookServer.
// The Server and Router are required fields, optional Logger and SecretToken can be provided.
type FastHTTPWebhookServer struct {
Logger Logger
Server *fasthttp.Server
Router *router.Router
Logger Logger
Server *fasthttp.Server
Router *router.Router
SecretToken string
}

// Start starts server
Expand All @@ -33,6 +37,18 @@ func (f FastHTTPWebhookServer) Stop(ctx context.Context) error {
// Note: If server's handler is not set, it will be set to router's handler
func (f FastHTTPWebhookServer) RegisterHandler(path string, handler func(data []byte) error) error {
f.Router.POST(path, func(ctx *fasthttp.RequestCtx) {
if f.SecretToken != "" {
secretToken := ctx.Request.Header.Peek(WebhookSecretTokenHeader)
if f.SecretToken != string(secretToken) {
if f.Logger != nil {
f.Logger.Errorf("Webhook handler: unauthorized: secret token does not match")
}

ctx.SetStatusCode(fasthttp.StatusUnauthorized)
return
}
}

if err := handler(ctx.PostBody()); err != nil {
if f.Logger != nil {
f.Logger.Errorf("Webhook handler: %s", err)
Expand All @@ -52,12 +68,13 @@ func (f FastHTTPWebhookServer) RegisterHandler(path string, handler func(data []
return nil
}

// HTTPWebhookServer represents http implementation of WebhookServer
// Note: The user should set both Server and ServeMux, only Logger is optional
// HTTPWebhookServer represents http implementation of WebhookServer.
// The Server and ServeMux are required fields, optional Logger and SecretToken can be provided.
type HTTPWebhookServer struct {
Logger Logger
Server *http.Server
ServeMux *http.ServeMux
Logger Logger
Server *http.Server
ServeMux *http.ServeMux
SecretToken string
}

// Start starts server
Expand All @@ -81,8 +98,7 @@ func (h HTTPWebhookServer) Stop(ctx context.Context) error {
// Note: If server's handler is not set, it will be set to serve mux handler
func (h HTTPWebhookServer) RegisterHandler(path string, handler func(data []byte) error) error {
h.ServeMux.HandleFunc(path, func(writer http.ResponseWriter, request *http.Request) {
if request.Method != http.MethodPost {
writer.WriteHeader(http.StatusMethodNotAllowed)
if !h.validateRequest(writer, request) {
return
}

Expand All @@ -96,6 +112,12 @@ func (h HTTPWebhookServer) RegisterHandler(path string, handler func(data []byte
return
}

if err = request.Body.Close(); err != nil {
if h.Logger != nil {
h.Logger.Errorf("Webhook handler: close body: %s", err)
}
}

if err = handler(data); err != nil {
if h.Logger != nil {
h.Logger.Errorf("Webhook handler: %s", err)
Expand All @@ -115,6 +137,27 @@ func (h HTTPWebhookServer) RegisterHandler(path string, handler func(data []byte
return nil
}

func (h HTTPWebhookServer) validateRequest(writer http.ResponseWriter, request *http.Request) bool {
if request.Method != http.MethodPost {
writer.WriteHeader(http.StatusMethodNotAllowed)
return false
}

if h.SecretToken != "" {
secretToken := request.Header.Get(WebhookSecretTokenHeader)
if h.SecretToken != secretToken {
if h.Logger != nil {
h.Logger.Errorf("Webhook handler: unauthorized: secret token does not match")
}

writer.WriteHeader(http.StatusUnauthorized)
return false
}
}

return true
}

// MultiBotWebhookServer represents multi bot implementation of WebhookServer,
// stable for running multiple bots from single server
type MultiBotWebhookServer struct {
Expand Down
60 changes: 54 additions & 6 deletions webhook_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package telego

import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -18,9 +19,10 @@ func TestFastHTTPWebhookServer_RegisterHandler(t *testing.T) {
addr := testAddress(t)

s := FastHTTPWebhookServer{
Logger: testLoggerType{},
Server: &fasthttp.Server{},
Router: router.New(),
Logger: testLoggerType{},
Server: &fasthttp.Server{},
Router: router.New(),
SecretToken: "secret",
}

go func() {
Expand All @@ -41,6 +43,7 @@ func TestFastHTTPWebhookServer_RegisterHandler(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
ctx.Request.Header.Set(WebhookSecretTokenHeader, s.SecretToken)
s.Server.Handler(ctx)

assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
Expand All @@ -59,12 +62,22 @@ func TestFastHTTPWebhookServer_RegisterHandler(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
ctx.Request.Header.Set(WebhookSecretTokenHeader, s.SecretToken)
ctx.Request.SetBody([]byte("err"))
s.Server.Handler(ctx)

assert.Equal(t, fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
})

t.Run("secret_token_invalid", func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
s.Server.Handler(ctx)

assert.Equal(t, fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
})

err = s.Stop(context.Background())
assert.NoError(t, err)
}
Expand All @@ -91,9 +104,10 @@ func TestHTTPWebhookServer_RegisterHandler(t *testing.T) {

t.Run("end_to_end", func(t *testing.T) {
s := HTTPWebhookServer{
Logger: testLoggerType{},
Server: &http.Server{}, //nolint:gosec
ServeMux: http.NewServeMux(),
Logger: testLoggerType{},
Server: &http.Server{}, //nolint:gosec
ServeMux: http.NewServeMux(),
SecretToken: "secret",
}

go func() {
Expand All @@ -113,6 +127,7 @@ func TestHTTPWebhookServer_RegisterHandler(t *testing.T) {
t.Run("success", func(t *testing.T) {
rc := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Header.Set(WebhookSecretTokenHeader, s.SecretToken)

s.Server.Handler.ServeHTTP(rc, req)

Expand All @@ -131,6 +146,7 @@ func TestHTTPWebhookServer_RegisterHandler(t *testing.T) {
t.Run("error_handler", func(t *testing.T) {
rc := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("err"))
req.Header.Set(WebhookSecretTokenHeader, s.SecretToken)

s.Server.Handler.ServeHTTP(rc, req)

Expand All @@ -140,12 +156,32 @@ func TestHTTPWebhookServer_RegisterHandler(t *testing.T) {
t.Run("error_read", func(t *testing.T) {
rc := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", errReader{})
req.Header.Set(WebhookSecretTokenHeader, s.SecretToken)

s.Server.Handler.ServeHTTP(rc, req)

assert.Equal(t, http.StatusInternalServerError, rc.Code)
})

t.Run("error_close", func(t *testing.T) {
rc := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", errReaderCloser{reader: strings.NewReader("ok")})
req.Header.Set(WebhookSecretTokenHeader, s.SecretToken)

s.Server.Handler.ServeHTTP(rc, req)

assert.Equal(t, http.StatusInternalServerError, rc.Code)
})

t.Run("secret_token_invalid", func(t *testing.T) {
rc := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", nil)

s.Server.Handler.ServeHTTP(rc, req)

assert.Equal(t, http.StatusUnauthorized, rc.Code)
})

err = s.Stop(context.Background())
assert.NoError(t, err)
})
Expand All @@ -157,6 +193,18 @@ func (e errReader) Read(_ []byte) (n int, err error) {
return 0, errTest
}

type errReaderCloser struct {
reader io.Reader
}

func (e errReaderCloser) Close() error {
return errTest
}

func (e errReaderCloser) Read(b []byte) (n int, err error) {
return e.reader.Read(b)
}

func TestMultiBotWebhookServer_RegisterHandler(t *testing.T) {
ts := &testServer{}
s := &MultiBotWebhookServer{
Expand Down

0 comments on commit 38d796d

Please sign in to comment.