Skip to content

Commit

Permalink
Rewrite AutomaticOptions to not use httptest
Browse files Browse the repository at this point in the history
httptest was adding an extra flag, which was sort of ugly. Instead,
reimplement the parts of its functionality we were using. Bonus: due to
specialization, it's now a bit more efficient as well!
  • Loading branch information
zenazn committed Apr 21, 2014
1 parent 7b91ca1 commit bc3ac1d
Showing 1 changed file with 54 additions and 31 deletions.
85 changes: 54 additions & 31 deletions web/middleware/options.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,73 @@
package middleware

import (
"io"
"net/http"
"net/http/httptest"
"strings"

"github.com/zenazn/goji/web"
)

// AutomaticOptions automatically return an appropriate "Allow" header when the
// request method is OPTIONS and the request would have otherwise been 404'd.
func AutomaticOptions(c *web.C, h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
// This will probably slow down OPTIONS calls a bunch, but it
// probably won't happen too much, and it'll just be hitting the
// 404 route anyways.
var fw *httptest.ResponseRecorder
pw := w
if r.Method == "OPTIONS" {
fw = httptest.NewRecorder()
pw = fw
}
type autoOptionsState int

h.ServeHTTP(pw, r)
const (
aosInit autoOptionsState = iota
aosHeaderWritten
aosProxying
)

if fw == nil {
return
}
// I originally used an httptest.ResponseRecorder here, but package httptest
// adds a flag which I'm not particularly eager to expose. This is essentially a
// ResponseRecorder that has been specialized for the purpose at hand to avoid
// the httptest dependency.
type autoOptionsProxy struct {
w http.ResponseWriter
c *web.C
state autoOptionsState
}

func (p *autoOptionsProxy) Header() http.Header {
return p.w.Header()
}

for k, v := range fw.Header() {
w.Header()[k] = v
func (p *autoOptionsProxy) Write(buf []byte) (int, error) {
switch p.state {
case aosInit:
p.state = aosHeaderWritten
case aosProxying:
return len(buf), nil
}
return p.w.Write(buf)
}

func (p *autoOptionsProxy) WriteHeader(code int) {
methods := getValidMethods(*p.c)
switch p.state {
case aosInit:
if methods != nil && code == http.StatusNotFound {
p.state = aosProxying
break
}
p.state = aosHeaderWritten
fallthrough
default:
p.w.WriteHeader(code)
return
}

methods := getValidMethods(*c)
methods = addMethod(methods, "OPTIONS")
p.w.Header().Set("Allow", strings.Join(methods, ", "))
p.w.WriteHeader(http.StatusOK)
}

if fw.Code == http.StatusNotFound && methods != nil {
methods = addMethod(methods, "OPTIONS")
w.Header().Set("Allow", strings.Join(methods, ", "))
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(fw.Code)
io.Copy(w, fw.Body)
// AutomaticOptions automatically return an appropriate "Allow" header when the
// request method is OPTIONS and the request would have otherwise been 404'd.
func AutomaticOptions(c *web.C, h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
w = &autoOptionsProxy{c: c, w: w}
}

h.ServeHTTP(w, r)
}

return http.HandlerFunc(fn)
Expand All @@ -62,8 +87,6 @@ func getValidMethods(c web.C) []string {
return nil
}

// Assumption: the list of methods is teensy, and that anything we could
// possibly want to do here is going to be fast.
func addMethod(methods []string, method string) []string {
for _, m := range methods {
if m == method {
Expand Down

0 comments on commit bc3ac1d

Please sign in to comment.