forked from goadesign/goa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmiddleware.go
320 lines (298 loc) · 9.62 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
package goa
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"regexp"
"runtime"
"strings"
"sync/atomic"
"time"
log "gopkg.in/inconshreveable/log15.v2"
"golang.org/x/net/context"
)
// Middleware is the canonical goa middleware signature: it receives the next
// handler in the chain and returns a handler that wraps it.
type Middleware func(Handler) Handler
// NewMiddleware converts m into a goa Middleware. The accepted argument types
// are:
//
//   - a goa middleware: goa.Middleware or func(goa.Handler) goa.Handler
//
//   - a goa handler: goa.Handler or func(*goa.Context) error
//
//   - an http middleware: func(http.Handler) http.Handler
//
//   - an http handler: http.Handler or func(http.ResponseWriter, *http.Request)
//
// Any other type yields a nil middleware and a non-nil error.
func NewMiddleware(m interface{}) (mw Middleware, err error) {
	switch v := m.(type) {
	case Middleware:
		return v, nil
	case func(Handler) Handler:
		return v, nil
	case Handler:
		return handlerToMiddleware(v), nil
	case func(*Context) error:
		return handlerToMiddleware(v), nil
	case func(http.Handler) http.Handler:
		wrap := func(next Handler) Handler {
			return func(ctx *Context) error {
				var nextErr error
				// When the http middleware invokes its inner handler, run the
				// next goa handler and remember the error it returns.
				inner := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
					nextErr = next(ctx)
				})
				rw := ctx.Value(respKey).(http.ResponseWriter)
				v(inner).ServeHTTP(rw, ctx.Request())
				return nextErr
			}
		}
		return wrap, nil
	case http.Handler:
		return httpHandlerToMiddleware(v.ServeHTTP), nil
	case func(http.ResponseWriter, *http.Request):
		return httpHandlerToMiddleware(v), nil
	default:
		return nil, fmt.Errorf("invalid middleware %#v", m)
	}
}
const (
	// ReqIDKey is the context key under which the RequestID middleware stores
	// the request ID value.
	ReqIDKey middlewareKey = 0

	// RequestIDHeader is the name of the HTTP header used to transmit the
	// request ID.
	RequestIDHeader = "X-Request-Id"
)

var (
	// reqID is the counter incremented for every generated request ID.
	reqID int64

	// reqPrefix is the prefix shared by every request ID generated by this
	// process; it is computed once in init.
	reqPrefix string
)
// init computes reqPrefix once at process start-up: the first ten characters
// of base64-encoded random bytes with '+' and '/' removed.
// Algorithm adapted from https://github.com/zenazn/goji/blob/master/web/middleware/request_id.go#L44-L50
func init() {
	strip := strings.NewReplacer("+", "", "/", "")
	var raw [12]byte
	encoded := ""
	for len(encoded) < 10 {
		// Twelve bytes encode to sixteen characters; removing '+' and '/' may
		// shorten that, so draw again until at least ten characters remain.
		rand.Read(raw[:])
		encoded = strip.Replace(base64.StdEncoding.EncodeToString(raw[:]))
	}
	reqPrefix = encoded[:10]
}
type (
	// middlewareKey is the unexported type goa middlewares use for the keys
	// of values they store in the context. Keeping the type unexported avoids
	// collisions with keys defined by other packages.
	middlewareKey int
)
// LogRequest creates a request logger middleware.
//
// For each request it logs "started" with the request method and URL. At
// debug level it logs the route params, the query string values and, when
// the request has a positive Content-Length, the payload stored in the
// context. After the wrapped handler returns, it logs "completed" with the
// response status, response length and elapsed time, and returns the
// handler's error unchanged.
//
// This middleware is aware of the RequestID middleware: if registered after
// it, the logger is tagged with that request ID; otherwise a random short ID
// is used.
func LogRequest() Middleware {
	return func(h Handler) Handler {
		return func(ctx *Context) error {
			// Named id (not reqID) to avoid shadowing the package-level counter.
			id := ctx.Value(ReqIDKey)
			if id == nil {
				id = shortID()
			}
			ctx.Logger = ctx.Logger.New("id", id)
			startedAt := time.Now()
			r := ctx.Value(reqKey).(*http.Request)
			ctx.Info("started", r.Method, r.URL.String())

			// Comma-ok assertions: a missing or differently-typed value yields a
			// nil map (length 0) rather than panicking inside the logger.
			if params, _ := ctx.Value(paramKey).(map[string]string); len(params) > 0 {
				logCtx := make(log.Ctx, len(params))
				for k, v := range params {
					logCtx[k] = v
				}
				ctx.Debug("params", logCtx)
			}
			if query, _ := ctx.Value(queryKey).(map[string][]string); len(query) > 0 {
				logCtx := make(log.Ctx, len(query))
				for k, v := range query {
					logCtx[k] = v
				}
				ctx.Debug("query", logCtx)
			}
			if r.ContentLength > 0 {
				payload := ctx.Value(payloadKey)
				if mp, ok := payload.(map[string]interface{}); ok {
					ctx.Debug("payload", log.Ctx(mp))
				} else {
					ctx.Debug("payload", "raw", payload)
				}
			}

			err := h(ctx)
			ctx.Info("completed", "status", ctx.ResponseStatus(),
				"bytes", ctx.ResponseLength(), "time", time.Since(startedAt).String())
			return err
		}
	}
}
// RequestID returns a middleware that stores a request ID in the context of
// every request under ReqIDKey; read it back with ctx.Value(ReqIDKey).
// The ID comes from the incoming RequestIDHeader header when that header is
// non-empty; otherwise one is generated from the process-wide prefix and an
// atomically incremented counter.
func RequestID() Middleware {
	return func(next Handler) Handler {
		return func(ctx *Context) error {
			id := ctx.Request().Header.Get(RequestIDHeader)
			if len(id) == 0 {
				seq := atomic.AddInt64(&reqID, 1)
				id = fmt.Sprintf("%s-%d", reqPrefix, seq)
			}
			ctx.SetValue(ReqIDKey, id)
			return next(ctx)
		}
	}
}
// Recover is a middleware that recovers panics raised by downstream handlers
// and turns them into an error returned to the caller.
//
// A string panic becomes an error "panic: <value>", an error panic is returned
// as-is, and any other value becomes "unknown panic". When ctx is non-nil, an
// internal error response (500) is also written:
//   - if ctx.Logger is set, the panic and a stack trace are logged at error
//     level, and when a request ID is available (see RequestID) the response
//     body is the generic status text plus that ID;
//   - otherwise the response body is the panic error message itself.
func Recover() Middleware {
	return func(h Handler) Handler {
		return func(ctx *Context) (err error) {
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				switch x := r.(type) {
				case string:
					err = fmt.Errorf("panic: %s", x)
				case error:
					err = x
				default:
					err = errors.New("unknown panic")
				}
				if ctx == nil {
					// No context to respond through, but still surface the panic
					// as an error rather than reporting success.
					return
				}

				status := http.StatusInternalServerError
				var message string
				if ctx.Logger != nil {
					if id := ctx.Value(ReqIDKey); id != nil {
						message = fmt.Sprintf(
							"%s\nRefer to the following token when contacting support: %s",
							http.StatusText(status),
							id)
					}
					// The stack trace is only used for the log record, so only
					// capture it when there is a logger to write it to.
					const size = 64 << 10 // 64KB
					buf := make([]byte, size)
					buf = buf[:runtime.Stack(buf, false)]
					// Skip the goroutine header line and this deferred function's
					// own frame (function line + file line).
					stack := strings.Split(string(buf), "\n")[3:]
					ctx.Logger.Error("panic", "err", err, "stack", stack)
				}
				// We must respond here, otherwise the default response is a 500
				// with "unhandled request".
				if message == "" {
					// Without a logger and/or request ID (from middleware) the only
					// reference we can give the caller is the full error message;
					// it is unlikely to make sense to them unless they know the
					// source code.
					message = err.Error()
				}
				ctx.Respond(status, []byte(message))
			}()
			return h(ctx)
		}
	}
}
// Timeout sets a global timeout for all controller actions.
// The timeout notification is made through the context; it is the responsibility of the request
// handler to handle it. For example:
//
//	func (ctrl *Controller) DoLongRunningAction(ctx *DoLongRunningActionContext) error {
//		action := NewLongRunning()      // setup long running action
//		c := make(chan error, 1)        // create return channel
//		go func() { c <- action.Run() } // Launch long running action goroutine
//		select {
//		case <-ctx.Done(): // timeout triggered
//			action.Cancel()    // cancel long running action
//			<-c                // wait for Run to return.
//			return ctx.Err()   // retrieve cancel reason
//		case err := <-c: // action finished on time
//			return err       // forward its return value
//		}
//	}
//
// Package golang.org/x/net/context/ctxhttp contains an implementation of an HTTP client which is
// context-aware:
//
//	func (ctrl *Controller) HttpAction(ctx *HttpActionContext) error {
//		req, err := http.NewRequest("GET", "http://iamaslowservice.com", nil)
//		// ...
//		resp, err := ctxhttp.Do(ctx, nil, req) // returns if timeout triggers
//		// ...
//	}
//
// Controller actions can check if a timeout is set by calling the context Deadline method.
//
// The derived context is canceled as soon as the wrapped handler returns, which releases its
// timer immediately instead of at the deadline. ctx.Context is then restored to the parent, so
// middleware further up the chain operates on the same context it passed in.
func Timeout(timeout time.Duration) Middleware {
	return func(h Handler) Handler {
		return func(ctx *Context) (err error) {
			parent := ctx.Context
			child, cancel := context.WithTimeout(parent, timeout)
			ctx.Context = child
			defer func() {
				cancel()
				ctx.Context = parent
			}()
			return h(ctx)
		}
	}
}
// RequireHeader returns a middleware that enforces a request header on
// requests whose URL path matches pathPattern; a nil pathPattern applies the
// check to every path. The header named requiredHeaderName must be non-empty
// and, when requiredHeaderValue is non-nil, its value must match that pattern.
// A request that fails the check is answered with failureStatus (e.g.
// http.StatusUnauthorized) and that status's standard text as the body, and
// the next handler is not called. All other requests proceed down the chain.
func RequireHeader(
	pathPattern *regexp.Regexp,
	requiredHeaderName string,
	requiredHeaderValue *regexp.Regexp,
	failureStatus int) Middleware {
	return func(next Handler) Handler {
		return func(ctx *Context) (err error) {
			req := ctx.Request()
			if pathPattern != nil && !pathPattern.MatchString(req.URL.Path) {
				// Path not covered by the rule: skip the check.
				return next(ctx)
			}
			value := req.Header.Get(requiredHeaderName)
			accepted := value != "" &&
				(requiredHeaderValue == nil || requiredHeaderValue.MatchString(value))
			if !accepted {
				return ctx.Respond(failureStatus, []byte(http.StatusText(failureStatus)))
			}
			return next(ctx)
		}
	}
}
// shortID returns a random 8-character string: 6 random bytes encoded with
// standard base64 (the alphabet may include '+' and '/'; there is no padding).
// The result is not guaranteed to be unique, so do not use it as a reliable
// identifier; it is intended for things like log correlation.
func shortID() string {
	var b [6]byte
	// Best effort: if crypto/rand fails, the unread bytes stay zero and the
	// result is still a well-formed 8-character ID, which is acceptable for
	// logging purposes.
	_, _ = io.ReadFull(rand.Reader, b[:])
	return base64.StdEncoding.EncodeToString(b[:])
}
// handlerToMiddleware adapts a plain goa handler into a middleware. The
// resulting middleware runs m first. If m returns an error, the chain stops
// and that error is returned; otherwise the next handler is invoked and its
// result returned.
func handlerToMiddleware(m Handler) Middleware {
	return func(next Handler) Handler {
		return func(ctx *Context) error {
			err := m(ctx)
			if err == nil {
				err = next(ctx)
			}
			return err
		}
	}
}
// httpHandlerToMiddleware adapts an http.HandlerFunc into a middleware. The
// middleware invokes m with the goa context acting as the http.ResponseWriter
// and the context's request, then always calls the next handler in the chain.
func httpHandlerToMiddleware(m http.HandlerFunc) Middleware {
	return func(next Handler) Handler {
		return func(ctx *Context) error {
			m(ctx, ctx.Request())
			return next(ctx)
		}
	}
}