Skip to content

Commit

Permalink
Added Context#IsWebSocket(), proxy fix header
Browse files Browse the repository at this point in the history
Signed-off-by: Vishal Rana <[email protected]>
  • Loading branch information
vishr committed Jun 4, 2017
1 parent 353a2f8 commit c3887eb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 38 deletions.
16 changes: 15 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type (
// IsTLS returns true if HTTP connection is TLS otherwise false.
IsTLS() bool

// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
IsWebSocket() bool

// Scheme returns the HTTP protocol scheme, `http` or `https`.
Scheme() string

Expand Down Expand Up @@ -219,6 +222,11 @@ func (c *context) IsTLS() bool {
return c.request.TLS != nil
}

func (c *context) IsWebSocket() bool {
upgrade := c.request.Header.Get(HeaderUpgrade)
return upgrade == "websocket" || upgrade == "Websocket"
}

func (c *context) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
Expand All @@ -227,10 +235,16 @@ func (c *context) Scheme() string {
}
if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" {
return scheme
}
}
if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" {
return scheme
}
if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" {
return "https"
}
if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" {
return scheme
}
return "http"
}

Expand Down
52 changes: 28 additions & 24 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,30 +165,34 @@ const (

// Headers
const (
HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAllow = "Allow"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
HeaderContentEncoding = "Content-Encoding"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderCookie = "Cookie"
HeaderSetCookie = "Set-Cookie"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXRealIP = "X-Real-IP"
HeaderXRequestID = "X-Request-ID"
HeaderServer = "Server"
HeaderOrigin = "Origin"
HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAllow = "Allow"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
HeaderContentEncoding = "Content-Encoding"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderCookie = "Cookie"
HeaderSetCookie = "Set-Cookie"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedProtocol = "X-Forwarded-Protocol"
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXUrlScheme = "X-Url-Scheme"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXRealIP = "X-Real-IP"
HeaderXRequestID = "X-Request-ID"
HeaderServer = "Server"
HeaderOrigin = "Origin"

// Access control
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
Expand Down
35 changes: 22 additions & 13 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/labstack/echo"
)

// TODO: Handle TLS proxy

type (
// ProxyConfig defines the config for Proxy middleware.
ProxyConfig struct {
Expand Down Expand Up @@ -63,25 +65,24 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
c.Error(errors.New("proxy raw, not a hijacker"))
return
}

in, _, err := h.Hijack()
if err != nil {
c.Error(fmt.Errorf("proxy raw hijack error=%v, url=%s", r.URL, err))
c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", r.URL, err))
return
}
defer in.Close()

out, err := net.Dial("tcp", t.URL.Host)
if err != nil {
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw dial error=%v, url=%s", r.URL, err))
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", r.URL, err))
c.Error(he)
return
}
defer out.Close()

err = r.Write(out)
if err != nil {
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw request copy error=%v, url=%s", r.URL, err))
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request copy error=%v, url=%s", r.URL, err))
c.Error(he)
return
}
Expand All @@ -96,7 +97,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
go cp(in, out)
err = <-errc
if err != nil && err != io.EOF {
c.Logger().Errorf("proxy raw error=%v, url=%s", r.URL, err)
c.Logger().Errorf("proxy raw, error=%v, url=%s", r.URL, err)
}
})
}
Expand Down Expand Up @@ -131,18 +132,26 @@ func Proxy(config ProxyConfig) echo.MiddlewareFunc {
return func(c echo.Context) (err error) {
req := c.Request()
res := c.Response()
t := config.Balancer.Next()
tgt := config.Balancer.Next()

// Proxy
upgrade := req.Header.Get(echo.HeaderUpgrade)
accept := req.Header.Get(echo.HeaderAccept)
// Fix header
if req.Header.Get(echo.HeaderXRealIP) == "" {
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
}
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}

// Proxy
switch {
case upgrade == "websocket" || upgrade == "Websocket":
proxyRaw(t, c).ServeHTTP(res, req)
case accept == "text/event-stream":
case c.IsWebSocket():
proxyRaw(tgt, c).ServeHTTP(res, req)
case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
default:
proxyHTTP(t).ServeHTTP(res, req)
proxyHTTP(tgt).ServeHTTP(res, req)
}

return
Expand Down

0 comments on commit c3887eb

Please sign in to comment.