diff --git a/web/middleware/options.go b/web/middleware/options.go index 4bb7b12..4bdce5f 100644 --- a/web/middleware/options.go +++ b/web/middleware/options.go @@ -1,48 +1,73 @@ package middleware import ( - "io" "net/http" - "net/http/httptest" "strings" "github.com/zenazn/goji/web" ) -// AutomaticOptions automatically return an appropriate "Allow" header when the -// request method is OPTIONS and the request would have otherwise been 404'd. -func AutomaticOptions(c *web.C, h http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - // This will probably slow down OPTIONS calls a bunch, but it - // probably won't happen too much, and it'll just be hitting the - // 404 route anyways. - var fw *httptest.ResponseRecorder - pw := w - if r.Method == "OPTIONS" { - fw = httptest.NewRecorder() - pw = fw - } +type autoOptionsState int - h.ServeHTTP(pw, r) +const ( + aosInit autoOptionsState = iota + aosHeaderWritten + aosProxying +) - if fw == nil { - return - } +// I originally used an httptest.ResponseRecorder here, but package httptest +// adds a flag which I'm not particularly eager to expose. This is essentially a +// ResponseRecorder that has been specialized for the purpose at hand to avoid +// the httptest dependency. +type autoOptionsProxy struct { + w http.ResponseWriter + c *web.C + state autoOptionsState +} + +func (p *autoOptionsProxy) Header() http.Header { + return p.w.Header() +} - for k, v := range fw.Header() { - w.Header()[k] = v +func (p *autoOptionsProxy) Write(buf []byte) (int, error) { + switch p.state { + case aosInit: + p.state = aosHeaderWritten + case aosProxying: + return len(buf), nil + } + return p.w.Write(buf) +} + +func (p *autoOptionsProxy) WriteHeader(code int) { + methods := getValidMethods(*p.c) + switch p.state { + case aosInit: + if methods != nil && code == http.StatusNotFound { + p.state = aosProxying + break } + p.state = aosHeaderWritten + fallthrough + default: + p.w.WriteHeader(code) + return + } - methods := getValidMethods(*c) + methods = addMethod(methods, "OPTIONS") + p.w.Header().Set("Allow", strings.Join(methods, ", ")) + p.w.WriteHeader(http.StatusOK) +} - if fw.Code == http.StatusNotFound && methods != nil { - methods = addMethod(methods, "OPTIONS") - w.Header().Set("Allow", strings.Join(methods, ", ")) - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(fw.Code) - io.Copy(w, fw.Body) +// AutomaticOptions automatically return an appropriate "Allow" header when the +// request method is OPTIONS and the request would have otherwise been 404'd. +func AutomaticOptions(c *web.C, h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + w = &autoOptionsProxy{c: c, w: w} } + + h.ServeHTTP(w, r) } return http.HandlerFunc(fn) @@ -62,8 +87,6 @@ func getValidMethods(c web.C) []string { return nil } -// Assumption: the list of methods is teensy, and that anything we could -// possibly want to do here is going to be fast. func addMethod(methods []string, method string) []string { for _, m := range methods { if m == method {