From bfaa7f776370f7750ceaed07e0a0667a0d2be600 Mon Sep 17 00:00:00 2001 From: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> Date: Thu, 1 Jun 2023 12:29:44 -0400 Subject: [PATCH] Add allowed http hosts configuration (#1566) --- api/server/allowed_hosts.go | 76 +++++++++++++++++++++++++++++++ api/server/allowed_hosts_test.go | 77 ++++++++++++++++++++++++++++++++ api/server/server.go | 4 +- config/config.go | 1 + config/flags.go | 1 + config/keys.go | 1 + node/config.go | 1 + node/node.go | 2 + 8 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 api/server/allowed_hosts.go create mode 100644 api/server/allowed_hosts_test.go diff --git a/api/server/allowed_hosts.go b/api/server/allowed_hosts.go new file mode 100644 index 000000000000..6745f0e17565 --- /dev/null +++ b/api/server/allowed_hosts.go @@ -0,0 +1,76 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package server + +import ( + "net" + "net/http" + "strings" + + "github.com/ava-labs/avalanchego/utils/set" +) + +const wildcard = "*" + +var _ http.Handler = (*allowedHostsHandler)(nil) + +func filterInvalidHosts( + handler http.Handler, + allowed []string, +) http.Handler { + s := set.Set[string]{} + + for _, host := range allowed { + if host == wildcard { + // wildcards match all hostnames, so just return the base handler + return handler + } + s.Add(strings.ToLower(host)) + } + + return &allowedHostsHandler{ + handler: handler, + hosts: s, + } +} + +// allowedHostsHandler is an implementation of http.Handler that validates the +// http host header of incoming requests. This can prevent DNS rebinding attacks +// which do not utilize CORS-headers. Http request host headers are validated +// against a whitelist to determine whether the request should be dropped or +// not. +type allowedHostsHandler struct { + handler http.Handler + hosts set.Set[string] +} + +func (a *allowedHostsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if the host header is missing we can serve this request because dns + // rebinding attacks rely on this header + if r.Host == "" { + a.handler.ServeHTTP(w, r) + return + } + + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + // either invalid (too many colons) or no port specified + host = r.Host + } + + if ipAddr := net.ParseIP(host); ipAddr != nil { + // accept requests from ips + a.handler.ServeHTTP(w, r) + return + } + + // a specific hostname - we need to check the whitelist to see if we should + // accept this r + if a.hosts.Contains(strings.ToLower(host)) { + a.handler.ServeHTTP(w, r) + return + } + + http.Error(w, "invalid host specified", http.StatusForbidden) +} diff --git a/api/server/allowed_hosts_test.go b/api/server/allowed_hosts_test.go new file mode 100644 index 000000000000..ae7a824834a9 --- /dev/null +++ b/api/server/allowed_hosts_test.go @@ -0,0 +1,77 @@ +// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAllowedHostsHandler_ServeHTTP(t *testing.T) { + tests := []struct { + name string + allowed []string + host string + serve bool + }{ + { + name: "no host header", + allowed: []string{"www.foobar.com"}, + host: "", + serve: true, + }, + { + name: "ip", + allowed: []string{"www.foobar.com"}, + host: "192.168.1.1", + serve: true, + }, + { + name: "hostname not allowed", + allowed: []string{"www.foobar.com"}, + host: "www.evil.com", + }, + { + name: "hostname allowed", + allowed: []string{"www.foobar.com"}, + host: "www.foobar.com", + serve: true, + }, + { + name: "wildcard", + allowed: []string{"*"}, + host: "www.foobar.com", + serve: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + baseHandler := &testHandler{} + + httpAllowedHostsHandler := filterInvalidHosts( + baseHandler, + test.allowed, + ) + + w := &httptest.ResponseRecorder{} + r := httptest.NewRequest("", "/", nil) + r.Host = test.host + + httpAllowedHostsHandler.ServeHTTP(w, r) + + if test.serve { + require.True(baseHandler.called) + return + } + + require.Equal(http.StatusForbidden, w.Code) + }) + } +} diff --git a/api/server/server.go b/api/server/server.go index 890351a90475..c3f384b28ead 100644 --- a/api/server/server.go +++ b/api/server/server.go @@ -119,6 +119,7 @@ func New( namespace string, registerer prometheus.Registerer, httpConfig HTTPConfig, + allowedHosts []string, wrappers ...Wrapper, ) (Server, error) { m, err := newMetrics(namespace, registerer) @@ -127,10 +128,11 @@ func New( } router := newRouter() + allowedHostsHandler := filterInvalidHosts(router, allowedHosts) corsHandler := cors.New(cors.Options{ AllowedOrigins: allowedOrigins, AllowCredentials: true, - }).Handler(router) + }).Handler(allowedHostsHandler) gzipHandler := gziphandler.GzipHandler(corsHandler) var handler http.Handler = http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/config/config.go b/config/config.go index 457344ccc85c..6acd1edbb509 100644 --- a/config/config.go +++ b/config/config.go @@ -244,6 +244,7 @@ func getHTTPConfig(v *viper.Viper) (node.HTTPConfig, error) { HTTPSKey: httpsKey, HTTPSCert: httpsCert, APIAllowedOrigins: v.GetStringSlice(HTTPAllowedOrigins), + HTTPAllowedHosts: v.GetStringSlice(HTTPAllowedHostsKey), ShutdownTimeout: v.GetDuration(HTTPShutdownTimeoutKey), ShutdownWait: v.GetDuration(HTTPShutdownWaitKey), } diff --git a/config/flags.go b/config/flags.go index 8665184d3a11..fb4fee72e01d 100644 --- a/config/flags.go +++ b/config/flags.go @@ -221,6 +221,7 @@ func addNodeFlags(fs *pflag.FlagSet) { fs.String(HTTPSCertFileKey, "", fmt.Sprintf("TLS certificate file for the HTTPs server. Ignored if %s is specified", HTTPSCertContentKey)) fs.String(HTTPSCertContentKey, "", "Specifies base64 encoded TLS certificate for the HTTPs server") fs.String(HTTPAllowedOrigins, "*", "Origins to allow on the HTTP port. Defaults to * which allows all origins. Example: https://*.avax.network https://*.avax-test.network") + fs.StringSlice(HTTPAllowedHostsKey, []string{"localhost"}, "List of acceptable host names in API requests. Provide the wildcard ('*') to accept requests from all hosts. API requests where the Host field is empty or an IP address will always be accepted. An API call whose HTTP Host field isn't acceptable will receive a 403 error code") fs.Duration(HTTPShutdownWaitKey, 0, "Duration to wait after receiving SIGTERM or SIGINT before initiating shutdown. The /health endpoint will return unhealthy during this duration") fs.Duration(HTTPShutdownTimeoutKey, 10*time.Second, "Maximum duration to wait for existing connections to complete during node shutdown") fs.Duration(HTTPReadTimeoutKey, 30*time.Second, "Maximum duration for reading the entire request, including the body. A zero or negative value means there will be no timeout") diff --git a/config/keys.go b/config/keys.go index 788702245675..4f611af68044 100644 --- a/config/keys.go +++ b/config/keys.go @@ -54,6 +54,7 @@ const ( HTTPSCertFileKey = "http-tls-cert-file" HTTPSCertContentKey = "http-tls-cert-file-content" HTTPAllowedOrigins = "http-allowed-origins" + HTTPAllowedHostsKey = "http-allowed-hosts" HTTPShutdownTimeoutKey = "http-shutdown-timeout" HTTPShutdownWaitKey = "http-shutdown-wait" HTTPReadTimeoutKey = "http-read-timeout" diff --git a/node/config.go b/node/config.go index d6ad4f6bc84c..cdfc84cab6e3 100644 --- a/node/config.go +++ b/node/config.go @@ -54,6 +54,7 @@ type HTTPConfig struct { HTTPSCert []byte `json:"-"` APIAllowedOrigins []string `json:"apiAllowedOrigins"` + HTTPAllowedHosts []string `json:"httpAllowedHosts"` ShutdownTimeout time.Duration `json:"shutdownTimeout"` ShutdownWait time.Duration `json:"shutdownWait"` diff --git a/node/node.go b/node/node.go index 9b8686b3add1..64d035d163a0 100644 --- a/node/node.go +++ b/node/node.go @@ -596,6 +596,7 @@ func (n *Node) initAPIServer() error { "api", n.MetricsRegisterer, n.Config.HTTPConfig.HTTPConfig, + n.Config.HTTPAllowedHosts, ) return err } @@ -618,6 +619,7 @@ func (n *Node) initAPIServer() error { "api", n.MetricsRegisterer, n.Config.HTTPConfig.HTTPConfig, + n.Config.HTTPAllowedHosts, a, ) if err != nil {