diff --git a/apis/base.go b/apis/base.go index 1a682ee94..0c607b6d9 100644 --- a/apis/base.go +++ b/apis/base.go @@ -121,6 +121,10 @@ func InitApi(app core.App) (*echo.Echo, error) { return nil, err } + // note: it is after the OnBeforeServe hook to ensure that the implicit + // cache is after any user custom defined middlewares + e.Use(eagerRequestDataCache(app)) + // catch all any route api.Any("/*", func(c echo.Context) error { return echo.ErrNotFound diff --git a/apis/base_test.go b/apis/base_test.go index 947548bfb..ee830ec6f 100644 --- a/apis/base_test.go +++ b/apis/base_test.go @@ -2,12 +2,15 @@ package apis_test import ( "errors" + "fmt" "net/http" + "strings" "testing" "github.com/labstack/echo/v5" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/tests" + "github.com/spf13/cast" ) func Test404(t *testing.T) { @@ -209,3 +212,83 @@ func TestRemoveTrailingSlashMiddleware(t *testing.T) { scenario.Test(t) } } + +func TestEagerRequestDataCache(t *testing.T) { + scenarios := []tests.ApiScenario{ + { + Name: "[UNKNOWN] unsupported eager cached request method", + Method: "UNKNOWN", + Url: "/custom", + Body: strings.NewReader(`{"name":"test123"}`), + BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + e.AddRoute(echo.Route{ + Method: "UNKNOWN", + Path: "/custom", + Handler: func(c echo.Context) error { + data := &struct { + Name string `json:"name"` + }{} + + if err := c.Bind(data); err != nil { + return err + } + + // since the unknown method is not eager cache support + // it should fail reading the json body twice + r := apis.RequestData(c) + if v := cast.ToString(r.Data["name"]); v != "" { + t.Fatalf("Expected empty request data body, got, %v", r.Data) + } + + return c.String(200, data.Name) + }, + }) + }, + ExpectedStatus: 200, + ExpectedContent: []string{"test123"}, + }, + } + + // supported eager cache request methods + supportedMethods := []string{"POST", "PUT", "PATCH", "DELETE"} + for _, m := range supportedMethods { + scenarios = append( + scenarios, + tests.ApiScenario{ + Name: fmt.Sprintf("[%s] valid cached json body request", m), + Method: http.MethodPost, + Url: "/custom", + Body: strings.NewReader(`{"name":"test123"}`), + BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { + e.AddRoute(echo.Route{ + Method: http.MethodPost, + Path: "/custom", + Handler: func(c echo.Context) error { + data := &struct { + Name string `json:"name"` + }{} + + if err := c.Bind(data); err != nil { + return err + } + + // try to read the body again + r := apis.RequestData(c) + if v := cast.ToString(r.Data["name"]); v != "test123" { + t.Fatalf("Expected request data with name %q, got, %q", "test123", v) + } + + return c.String(200, data.Name) + }, + }) + }, + ExpectedStatus: 200, + ExpectedContent: []string{"test123"}, + }, + ) + } + + for _, scenario := range scenarios { + scenario.Test(t) + } +} diff --git a/apis/middlewares.go b/apis/middlewares.go index a1c0a51de..e5276a5b7 100644 --- a/apis/middlewares.go +++ b/apis/middlewares.go @@ -385,3 +385,19 @@ func realUserIp(r *http.Request, fallbackIp string) string { return fallbackIp } + +// eagerRequestDataCache ensures that the request data is cached in the request +// context to allow reading for example the json request body data more than once. +func eagerRequestDataCache(app core.App) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + switch c.Request().Method { + // currently we are eagerly caching only the requests with body + case "POST", "PUT", "PATCH", "DELETE": + RequestData(c) + } + + return next(c) + } + } +}