Skip to content

Commit

Permalink
Add allowed http hosts configuration (ava-labs#1566)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-kim authored Jun 1, 2023
1 parent 8fb8afe commit bfaa7f7
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 1 deletion.
76 changes: 76 additions & 0 deletions api/server/allowed_hosts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package server

import (
"net"
"net/http"
"strings"

"github.com/ava-labs/avalanchego/utils/set"
)

const wildcard = "*"

var _ http.Handler = (*allowedHostsHandler)(nil)

func filterInvalidHosts(
handler http.Handler,
allowed []string,
) http.Handler {
s := set.Set[string]{}

for _, host := range allowed {
if host == wildcard {
// wildcards match all hostnames, so just return the base handler
return handler
}
s.Add(strings.ToLower(host))
}

return &allowedHostsHandler{
handler: handler,
hosts: s,
}
}

// allowedHostsHandler is an implementation of http.Handler that validates the
// http host header of incoming requests. This can prevent DNS rebinding attacks
// which do not utilize CORS-headers. Http request host headers are validated
// against a whitelist to determine whether the request should be dropped or
// not.
type allowedHostsHandler struct {
handler http.Handler
hosts set.Set[string]
}

func (a *allowedHostsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// if the host header is missing we can serve this request because dns
// rebinding attacks rely on this header
if r.Host == "" {
a.handler.ServeHTTP(w, r)
return
}

host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// either invalid (too many colons) or no port specified
host = r.Host
}

if ipAddr := net.ParseIP(host); ipAddr != nil {
// accept requests from ips
a.handler.ServeHTTP(w, r)
return
}

// a specific hostname - we need to check the whitelist to see if we should
// accept this r
if a.hosts.Contains(strings.ToLower(host)) {
a.handler.ServeHTTP(w, r)
return
}

http.Error(w, "invalid host specified", http.StatusForbidden)
}
77 changes: 77 additions & 0 deletions api/server/allowed_hosts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package server

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/require"
)

func TestAllowedHostsHandler_ServeHTTP(t *testing.T) {
tests := []struct {
name string
allowed []string
host string
serve bool
}{
{
name: "no host header",
allowed: []string{"www.foobar.com"},
host: "",
serve: true,
},
{
name: "ip",
allowed: []string{"www.foobar.com"},
host: "192.168.1.1",
serve: true,
},
{
name: "hostname not allowed",
allowed: []string{"www.foobar.com"},
host: "www.evil.com",
},
{
name: "hostname allowed",
allowed: []string{"www.foobar.com"},
host: "www.foobar.com",
serve: true,
},
{
name: "wildcard",
allowed: []string{"*"},
host: "www.foobar.com",
serve: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require := require.New(t)

baseHandler := &testHandler{}

httpAllowedHostsHandler := filterInvalidHosts(
baseHandler,
test.allowed,
)

w := &httptest.ResponseRecorder{}
r := httptest.NewRequest("", "/", nil)
r.Host = test.host

httpAllowedHostsHandler.ServeHTTP(w, r)

if test.serve {
require.True(baseHandler.called)
return
}

require.Equal(http.StatusForbidden, w.Code)
})
}
}
4 changes: 3 additions & 1 deletion api/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func New(
namespace string,
registerer prometheus.Registerer,
httpConfig HTTPConfig,
allowedHosts []string,
wrappers ...Wrapper,
) (Server, error) {
m, err := newMetrics(namespace, registerer)
Expand All @@ -127,10 +128,11 @@ func New(
}

router := newRouter()
allowedHostsHandler := filterInvalidHosts(router, allowedHosts)
corsHandler := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowCredentials: true,
}).Handler(router)
}).Handler(allowedHostsHandler)
gzipHandler := gziphandler.GzipHandler(corsHandler)
var handler http.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) {
HTTPSKey: httpsKey,
HTTPSCert: httpsCert,
APIAllowedOrigins: v.GetStringSlice(HTTPAllowedOrigins),
HTTPAllowedHosts: v.GetStringSlice(HTTPAllowedHostsKey),
ShutdownTimeout: v.GetDuration(HTTPShutdownTimeoutKey),
ShutdownWait: v.GetDuration(HTTPShutdownWaitKey),
}
Expand Down
1 change: 1 addition & 0 deletions config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ func addNodeFlags(fs *pflag.FlagSet) {
fs.String(HTTPSCertFileKey, "", fmt.Sprintf("TLS certificate file for the HTTPs server. Ignored if %s is specified", HTTPSCertContentKey))
fs.String(HTTPSCertContentKey, "", "Specifies base64 encoded TLS certificate for the HTTPs server")
fs.String(HTTPAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network")
fs.StringSlice(HTTPAllowedHostsKey, []string{"localhost"}, "List of acceptable host names in API requests. Provide the wildcard ('*') to accept requests from all hosts. API requests where the Host field is empty or an IP address will always be accepted. An API call whose HTTP Host field isn't acceptable will receive a 403 error code")
fs.Duration(HTTPShutdownWaitKey, 0, "Duration to wait after receiving SIGTERM or SIGINT before initiating shutdown. The /health endpoint will return unhealthy during this duration")
fs.Duration(HTTPShutdownTimeoutKey, 10*time.Second, "Maximum duration to wait for existing connections to complete during node shutdown")
fs.Duration(HTTPReadTimeoutKey, 30*time.Second, "Maximum duration for reading the entire request, including the body. A zero or negative value means there will be no timeout")
Expand Down
1 change: 1 addition & 0 deletions config/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const (
HTTPSCertFileKey = "http-tls-cert-file"
HTTPSCertContentKey = "http-tls-cert-file-content"
HTTPAllowedOrigins = "http-allowed-origins"
HTTPAllowedHostsKey = "http-allowed-hosts"
HTTPShutdownTimeoutKey = "http-shutdown-timeout"
HTTPShutdownWaitKey = "http-shutdown-wait"
HTTPReadTimeoutKey = "http-read-timeout"
Expand Down
1 change: 1 addition & 0 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type HTTPConfig struct {
HTTPSCert []byte `json:"-"`

APIAllowedOrigins []string `json:"apiAllowedOrigins"`
HTTPAllowedHosts []string `json:"httpAllowedHosts"`

ShutdownTimeout time.Duration `json:"shutdownTimeout"`
ShutdownWait time.Duration `json:"shutdownWait"`
Expand Down
2 changes: 2 additions & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ func (n *Node) initAPIServer() error {
"api",
n.MetricsRegisterer,
n.Config.HTTPConfig.HTTPConfig,
n.Config.HTTPAllowedHosts,
)
return err
}
Expand All @@ -618,6 +619,7 @@ func (n *Node) initAPIServer() error {
"api",
n.MetricsRegisterer,
n.Config.HTTPConfig.HTTPConfig,
n.Config.HTTPAllowedHosts,
a,
)
if err != nil {
Expand Down

0 comments on commit bfaa7f7

Please sign in to comment.