Skip to content

Commit 7f7d99c

Browse files
cubic3draphael
authored andcommitted
Added Method Not Allowed default handler (goadesign#1374)
* Added Method Not Allowed default handler * Changed MethodNotAllowed description comments * MethodNotAllowed error distinguishes singular/plural
1 parent 4937b25 commit 7f7d99c

File tree

6 files changed

+146
-2
lines changed

6 files changed

+146
-2
lines changed

error.go

+15
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ var (
6969
// ErrNotFound is the error returned to requests that don't match a registered handler.
7070
ErrNotFound = NewErrorClass("not_found", 404)
7171

72+
// ErrMethodNotAllowed is the error returned to requests that match the path of a registered
73+
// handler but not the HTTP method.
74+
ErrMethodNotAllowed = NewErrorClass("method_not_allowed", 405)
75+
7276
// ErrInternal is the class of error used for uncaught errors.
7377
ErrInternal = NewErrorClass("internal", 500)
7478
)
@@ -245,6 +249,17 @@ func NoAuthMiddleware(schemeName string) error {
245249
return ErrNoAuthMiddleware(msg, "scheme", schemeName)
246250
}
247251

252+
// MethodNotAllowedError is the error produced to requests that match the path of a registered
253+
// handler but not the HTTP method.
254+
func MethodNotAllowedError(method string, allowed []string) error {
255+
var plural string
256+
if len(allowed) > 1 {
257+
plural = " one of"
258+
}
259+
msg := fmt.Sprintf("Method %s must be%s %s", method, plural, strings.Join(allowed, ", "))
260+
return ErrMethodNotAllowed(msg, "method", method, "allowed", strings.Join(allowed, ", "))
261+
}
262+
248263
// Error returns the error occurrence details.
249264
func (e *ErrorResponse) Error() string {
250265
msg := fmt.Sprintf("[%s] %d %s: %s", e.ID, e.Status, e.Code, e.Detail)

error_test.go

+41
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"strings"
78

89
. "github.com/onsi/ginkgo"
910
. "github.com/onsi/gomega"
@@ -121,6 +122,46 @@ var _ = Describe("MissingHeaderError", func() {
121122
})
122123
})
123124

125+
var _ = Describe("MethodNotAllowedError", func() {
126+
var valErr error
127+
method := "POST"
128+
var allowed []string
129+
130+
JustBeforeEach(func() {
131+
valErr = MethodNotAllowedError(method, allowed)
132+
})
133+
134+
BeforeEach(func() {
135+
allowed = []string{"OPTIONS", "GET"}
136+
})
137+
138+
It("creates a http error", func() {
139+
Ω(valErr).ShouldNot(BeNil())
140+
Ω(valErr).Should(BeAssignableToTypeOf(&ErrorResponse{}))
141+
err := valErr.(*ErrorResponse)
142+
Ω(err.Detail).Should(ContainSubstring(method))
143+
Ω(err.Detail).Should(ContainSubstring(strings.Join(allowed, ", ")))
144+
})
145+
146+
Context("multiple allowed methods", func() {
147+
It("should use plural", func() {
148+
err := valErr.(*ErrorResponse)
149+
Ω(err.Detail).Should(ContainSubstring("one of"))
150+
})
151+
})
152+
153+
Context("single allowed method", func() {
154+
BeforeEach(func() {
155+
allowed = []string{"GET"}
156+
})
157+
158+
It("should not use plural", func() {
159+
err := valErr.(*ErrorResponse)
160+
Ω(err.Detail).ShouldNot(ContainSubstring("one of"))
161+
})
162+
})
163+
})
164+
124165
var _ = Describe("InvalidEnumValueError", func() {
125166
var valErr error
126167
ctx := "ctx"

mux.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ type (
1212
// The values argument includes both the querystring and path parameter values.
1313
MuxHandler func(http.ResponseWriter, *http.Request, url.Values)
1414

15+
// MethodNotAllowedHandler provides the implementation for an MethodNotAllowed
16+
// handler. The values argument includes both the querystring and path parameter
17+
// values. The methods argument includes both the allowed method identifier
18+
// and the registered handler.
19+
MethodNotAllowedHandler func(http.ResponseWriter, *http.Request, url.Values, map[string]httptreemux.HandlerFunc)
20+
1521
// ServeMux is the interface implemented by the service request muxes.
1622
// It implements http.Handler and makes it possible to register request handlers for
1723
// specific HTTP methods and request path via the Handle method.
@@ -23,6 +29,10 @@ type (
2329
// handler registered with Handle. The values argument given to the handler is
2430
// always nil.
2531
HandleNotFound(handle MuxHandler)
32+
// HandleMethodNotAllowed sets the MethodNotAllowedHandler invoked for requests
33+
// that match the path of a handler but not its HTTP method. The values argument
34+
// given to the Handler is always nil.
35+
HandleMethodNotAllowed(handle MethodNotAllowedHandler)
2636
// Lookup returns the MuxHandler associated with the given HTTP method and path.
2737
Lookup(method, path string) MuxHandler
2838
}
@@ -69,8 +79,13 @@ func (m *mux) HandleNotFound(handle MuxHandler) {
6979
handle(rw, req, nil)
7080
}
7181
m.router.NotFoundHandler = nfh
82+
}
83+
84+
// HandleMethodNotAllowed sets the MuxHandler invoked for requests that match
85+
// the path of a handler but not its HTTP method.
86+
func (m *mux) HandleMethodNotAllowed(handle MethodNotAllowedHandler) {
7287
mna := func(rw http.ResponseWriter, req *http.Request, methods map[string]httptreemux.HandlerFunc) {
73-
handle(rw, req, nil)
88+
handle(rw, req, nil, methods)
7489
}
7590
m.router.MethodNotAllowedHandler = mna
7691
}

mux_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,21 @@ var _ = Describe("Mux", func() {
6666
})
6767
})
6868

69+
Context("with registered handlers and wrong method", func() {
70+
const handlerMeth = "POST"
71+
const reqMeth = "GET"
72+
const reqPath = "/foo"
73+
74+
BeforeEach(func() {
75+
var err error
76+
req, err = http.NewRequest(reqMeth, reqPath, nil)
77+
Ω(err).ShouldNot(HaveOccurred())
78+
mux.Handle(handlerMeth, reqPath, func(rw http.ResponseWriter, req *http.Request, vals url.Values) {})
79+
})
80+
81+
It("returns 405 to not allowed method", func() {
82+
Ω(rw.Status).Should(Equal(405))
83+
})
84+
})
85+
6986
})

service.go

+35-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"strings"
1414

1515
"context"
16+
17+
"github.com/dimfeld/httptreemux"
1618
)
1719

1820
type (
@@ -100,7 +102,8 @@ func New(name string) *Service {
100102

101103
cancel: cancel,
102104
}
103-
notFoundHandler Handler
105+
notFoundHandler Handler
106+
methodNotAllowedHandler Handler
104107
)
105108

106109
// Setup default NotFound handler
@@ -127,6 +130,37 @@ func New(name string) *Service {
127130
}
128131
})
129132

133+
// Setup default MethodNotAllowed handler
134+
mux.HandleMethodNotAllowed(func(rw http.ResponseWriter, req *http.Request, params url.Values, methods map[string]httptreemux.HandlerFunc) {
135+
if resp := ContextResponse(ctx); resp != nil && resp.Written() {
136+
return
137+
}
138+
// Use closure to do lazy computation of middleware chain so all middlewares are
139+
// registered.
140+
if methodNotAllowedHandler == nil {
141+
methodNotAllowedHandler = func(_ context.Context, rw http.ResponseWriter, req *http.Request) error {
142+
allowedMethods := make([]string, len(methods))
143+
i := 0
144+
for k := range methods {
145+
allowedMethods[i] = k
146+
i++
147+
}
148+
rw.Header().Set("Allow", strings.Join(allowedMethods, ", "))
149+
return MethodNotAllowedError(req.Method, allowedMethods)
150+
}
151+
chain := service.middleware
152+
ml := len(chain)
153+
for i := range chain {
154+
methodNotAllowedHandler = chain[ml-i-1](methodNotAllowedHandler)
155+
}
156+
}
157+
ctx := NewContext(service.Context, rw, req, params)
158+
err := methodNotAllowedHandler(ctx, ContextResponse(ctx), req)
159+
if !ContextResponse(ctx).Written() {
160+
service.Send(ctx, 405, err)
161+
}
162+
})
163+
130164
return service
131165
}
132166

service_test.go

+22
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,28 @@ var _ = Describe("Service", func() {
8585
})
8686
})
8787

88+
Describe("MethodNotAllowed", func() {
89+
var rw *TestResponseWriter
90+
var req *http.Request
91+
92+
JustBeforeEach(func() {
93+
rw = &TestResponseWriter{ParentHeader: http.Header{}}
94+
s.Mux.ServeHTTP(rw, req)
95+
})
96+
97+
BeforeEach(func() {
98+
req, _ = http.NewRequest("GET", "/foo", nil)
99+
s.Mux.Handle("POST", "/foo", func(rw http.ResponseWriter, req *http.Request, vals url.Values) {})
100+
s.Mux.Handle("PUT", "/foo", func(rw http.ResponseWriter, req *http.Request, vals url.Values) {})
101+
})
102+
103+
It("handles requests with wrong method but existing endpoint", func() {
104+
Ω(rw.Status).Should(Equal(405))
105+
Ω(rw.Header().Get("Allow")).Should(Or(Equal("POST, PUT"), Equal("PUT, POST")))
106+
Ω(string(rw.Body)).Should(MatchRegexp(`{"id":".*","code":"method_not_allowed","status":405,"detail":".*","meta":{.*}}` + "\n"))
107+
})
108+
})
109+
88110
Describe("MaxRequestBodyLength", func() {
89111
var rw *TestResponseWriter
90112
var req *http.Request

0 commit comments

Comments
 (0)