Skip to content

Commit 3213297

Browse files
author
Raphaël Simon
committed
Remove need for "finalizing" service. (goadesign#493)
Instead use closures to lazily compute middleware chains. This makes it possible to mount controllers and middlewares in any order.
1 parent 77f479a commit 3213297

File tree

2 files changed

+71
-65
lines changed

2 files changed

+71
-65
lines changed

context.go

+14
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const (
1717
paramsKey
1818
logKey
1919
logContextKey
20+
errKey
2021
securityScopesKey
2122
)
2223

@@ -83,6 +84,11 @@ func WithLogContext(ctx context.Context, keyvals ...interface{}) context.Context
8384
return WithLogger(ctx, nl)
8485
}
8586

87+
// WithError creates a context with the given error.
88+
func WithError(ctx context.Context, err error) context.Context {
89+
return context.WithValue(ctx, errKey, err)
90+
}
91+
8692
// ContextController extracts the controller name from the given context.
8793
func ContextController(ctx context.Context) string {
8894
if c := ctx.Value(ctrlKey); c != nil {
@@ -123,6 +129,14 @@ func ContextLogger(ctx context.Context) LogAdapter {
123129
return nil
124130
}
125131

132+
// ContextError extracts the error from the given context.
133+
func ContextError(ctx context.Context) error {
134+
if err := ctx.Value(errKey); err != nil {
135+
return err.(error)
136+
}
137+
return nil
138+
}
139+
126140
// SwitchWriter overrides the underlying response writer. It returns the response
127141
// writer that was previously set.
128142
func (r *ResponseData) SwitchWriter(rw http.ResponseWriter) http.ResponseWriter {

service.go

+57-65
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ type (
4242
// available to all request handlers.
4343
Context context.Context
4444

45-
finalized bool // Whether controllers have been mounted
4645
middleware []Middleware // Middleware chain
47-
notFound Handler // Handler of requests that don't match registered mux handlers
4846
cancel context.CancelFunc // Service context cancel signal trigger
4947
decoderPools map[string]*decoderPool // Registered decoders for the service
5048
encoderPools map[string]*encoderPool // Registered encoders for the service
@@ -53,9 +51,12 @@ type (
5351

5452
// Controller defines the common fields and behavior of generated controllers.
5553
Controller struct {
56-
Name string // Controller resource name
57-
Service *Service // Service that exposes the controller
58-
Context context.Context // Controller root context
54+
// Controller resource name
55+
Name string
56+
// Service that exposes the controller
57+
Service *Service
58+
// Controller root context
59+
Context context.Context
5960

6061
middleware []Middleware // Controller specific middleware if any
6162
}
@@ -92,15 +93,26 @@ func New(name string) *Service {
9293
decoderPools: map[string]*decoderPool{},
9394
encoderPools: map[string]*encoderPool{},
9495
encodableContentTypes: []string{},
95-
notFound: func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
96-
return ErrNotFound(req.URL.Path)
97-
},
9896
}
97+
notFoundHandler Handler
9998
)
10099

100+
// Setup default NotFound handler
101101
mux.HandleNotFound(func(rw http.ResponseWriter, req *http.Request, params url.Values) {
102+
// Use closure to do lazy computation of middleware chain so all middlewares are
103+
// registered.
104+
if notFoundHandler == nil {
105+
notFoundHandler = func(_ context.Context, _ http.ResponseWriter, req *http.Request) error {
106+
return ErrNotFound(req.URL.Path)
107+
}
108+
chain := service.middleware
109+
ml := len(chain)
110+
for i := range chain {
111+
notFoundHandler = chain[ml-i-1](notFoundHandler)
112+
}
113+
}
102114
ctx := NewContext(service.Context, rw, req, params)
103-
err := service.notFound(ctx, rw, req)
115+
err := notFoundHandler(ctx, ContextResponse(ctx), req)
104116
if !ContextResponse(ctx).Written() {
105117
service.Send(ctx, 404, err)
106118
}
@@ -119,9 +131,6 @@ func (service *Service) CancelAll() {
119131
// goa comes with a set of commonly used middleware, see the middleware package.
120132
// Controller specific middleware should be mounted using the Controller struct Use method instead.
121133
func (service *Service) Use(m Middleware) {
122-
if service.finalized {
123-
panic("goa: cannot mount middleware after controller")
124-
}
125134
service.middleware = append(service.middleware, m)
126135
}
127136

@@ -179,27 +188,6 @@ func (service *Service) ServeFiles(path, filename string) error {
179188
return ctrl.ServeFiles(path, filename)
180189
}
181190

182-
// finalize wraps the NotFound handler with the final middleware chain.
183-
// Use cannot be called after finalize has.
184-
func (service *Service) finalize() {
185-
if service.finalized {
186-
return
187-
}
188-
notFound := service.notFound
189-
handler := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
190-
if !ContextResponse(ctx).Written() {
191-
return notFound(ctx, rw, req)
192-
}
193-
return nil
194-
}
195-
ml := len(service.middleware)
196-
for i := range service.middleware {
197-
handler = service.middleware[ml-i-1](handler)
198-
}
199-
service.notFound = handler
200-
service.finalized = true
201-
}
202-
203191
// ServeFiles replies to the request with the contents of the named file or directory. See
204192
// FileHandler for details.
205193
func (ctrl *Controller) ServeFiles(path, filename string) error {
@@ -220,9 +208,6 @@ func (ctrl *Controller) ServeFiles(path, filename string) error {
220208
// Use adds a middleware to the controller.
221209
// Service-wide middleware should be added via the Service Use method instead.
222210
func (ctrl *Controller) Use(m Middleware) {
223-
if ctrl.Service.finalized {
224-
panic("goa: cannot mount middleware after controller")
225-
}
226211
ctrl.middleware = append(ctrl.middleware, m)
227212
}
228213

@@ -232,22 +217,41 @@ func (ctrl *Controller) Use(m Middleware) {
232217
// This function is intended for the controller generated code. User code should not need to call
233218
// it directly.
234219
func (ctrl *Controller) MuxHandler(name string, hdlr Handler, unm Unmarshaler) MuxHandler {
235-
// Make sure middleware doesn't get mounted later
236-
ctrl.Service.finalize()
220+
// Use closure to enable late computation of handlers to ensure all middleware has been
221+
// registered.
222+
var handler, invalidPayloadHandler Handler
237223

238-
// Setup middleware outside of closure
239-
middleware := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
240-
if !ContextResponse(ctx).Written() {
241-
return hdlr(ctx, rw, req)
242-
}
243-
return nil
244-
}
245-
chain := append(ctrl.Service.middleware, ctrl.middleware...)
246-
ml := len(chain)
247-
for i := range chain {
248-
middleware = chain[ml-i-1](middleware)
249-
}
250224
return func(rw http.ResponseWriter, req *http.Request, params url.Values) {
225+
// Build handler middleware chains on first invocation
226+
if handler == nil {
227+
handler = func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
228+
if !ContextResponse(ctx).Written() {
229+
return hdlr(ctx, rw, req)
230+
}
231+
return nil
232+
}
233+
invalidPayloadHandler = func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
234+
rw.Header().Set("Content-Type", ErrorMediaIdentifier)
235+
status := 400
236+
err := ContextError(ctx)
237+
if err == nil {
238+
err = fmt.Errorf("unknown error")
239+
}
240+
body := ErrInvalidEncoding(err)
241+
if err.Error() == "http: request body too large" {
242+
status = 413
243+
body = ErrRequestBodyTooLarge("body length exceeds %d bytes", MaxRequestBodyLength)
244+
}
245+
return ctrl.Service.Send(ctx, status, body)
246+
}
247+
chain := append(ctrl.Service.middleware, ctrl.middleware...)
248+
ml := len(chain)
249+
for i := range chain {
250+
handler = chain[ml-i-1](handler)
251+
invalidPayloadHandler = chain[ml-i-1](invalidPayloadHandler)
252+
}
253+
}
254+
251255
// Build context
252256
ctx := NewContext(WithAction(ctrl.Context, name), rw, req, params)
253257

@@ -263,24 +267,12 @@ func (ctrl *Controller) MuxHandler(name string, hdlr Handler, unm Unmarshaler) M
263267
}
264268

265269
// Handle invalid payload
266-
handler := middleware
267270
if err != nil {
268-
handler = func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
269-
rw.Header().Set("Content-Type", ErrorMediaIdentifier)
270-
status := 400
271-
body := ErrInvalidEncoding(err)
272-
if err.Error() == "http: request body too large" {
273-
status = 413
274-
body = ErrRequestBodyTooLarge("body length exceeds %d bytes", MaxRequestBodyLength)
275-
}
276-
return ctrl.Service.Send(ctx, status, body)
277-
}
278-
for i := range chain {
279-
handler = chain[ml-i-1](handler)
280-
}
271+
ctx = WithError(ctx, err)
272+
handler = invalidPayloadHandler
281273
}
282274

283-
// Invoke middleware chain, errors should be caught earlier, e.g. by ErrorHandler middleware
275+
// Invoke handler
284276
if err := handler(ctx, ContextResponse(ctx), req); err != nil {
285277
LogError(ctx, "uncaught error", "err", err)
286278
respBody := fmt.Sprintf("Internal error: %s", err) // Sprintf catches panics

0 commit comments

Comments
 (0)