Skip to content

Commit

Permalink
Use Dial with context
Browse files Browse the repository at this point in the history
  • Loading branch information
ash2k committed May 18, 2018
1 parent 77a08ee commit 5e8e570
Show file tree
Hide file tree
Showing 25 changed files with 111 additions and 110 deletions.
6 changes: 3 additions & 3 deletions cmd/kube-apiserver/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func CreateNodeDialer(s completedServerRunOptions) (tunneler.Tunneler, *http.Tra
// Proxying to pods and services is IP-based... don't expect to be able to verify the hostname
proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true}
proxyTransport := utilnet.SetTransportDefaults(&http.Transport{
Dial: proxyDialerFn,
DialContext: proxyDialerFn,
TLSClientConfig: proxyTLSClientConfig,
})
return nodeTunneler, proxyTransport, nil
Expand Down Expand Up @@ -522,8 +522,8 @@ func BuildGenericConfig(
if err != nil {
return nil, err
}
if proxyTransport != nil && proxyTransport.Dial != nil {
ret.Dial = proxyTransport.Dial
if proxyTransport != nil && proxyTransport.DialContext != nil {
ret.Dial = proxyTransport.DialContext
}
return ret, err
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/kubelet/client/kubelet_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func MakeTransport(config *KubeletClientConfig) (http.RoundTripper, error) {
rt := http.DefaultTransport
if config.Dial != nil || tlsConfig != nil {
rt = utilnet.SetOldTransportDefaults(&http.Transport{
Dial: config.Dial,
DialContext: config.Dial,
TLSClientConfig: tlsConfig,
})
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/master/master_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package master

import (
"context"
"crypto/tls"
"encoding/json"
"io/ioutil"
Expand Down Expand Up @@ -108,7 +109,7 @@ func setUp(t *testing.T) (*etcdtesting.EtcdTestServer, Config, informers.SharedI
config.GenericConfig.LoopbackClientConfig = &restclient.Config{APIPath: "/api", ContentConfig: restclient.ContentConfig{NegotiatedSerializer: legacyscheme.Codecs}}
config.ExtraConfig.KubeletClientConfig = kubeletclient.KubeletClientConfig{Port: 10250}
config.ExtraConfig.ProxyTransport = utilnet.SetTransportDefaults(&http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return nil, nil },
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil },
TLSClientConfig: &tls.Config{},
})

Expand Down
6 changes: 3 additions & 3 deletions pkg/master/tunneler/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type AddressFunc func() (addresses []string, err error)
type Tunneler interface {
Run(AddressFunc)
Stop()
Dial(net, addr string) (net.Conn, error)
Dial(ctx context.Context, net, addr string) (net.Conn, error)
SecondsSinceSync() int64
SecondsSinceSSHKeySync() int64
}
Expand Down Expand Up @@ -149,8 +149,8 @@ func (c *SSHTunneler) Stop() {
}
}

func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) {
return c.tunnels.Dial(net, addr)
func (c *SSHTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
return c.tunnels.Dial(ctx, net, addr)
}

func (c *SSHTunneler) SecondsSinceSync() int64 {
Expand Down
11 changes: 6 additions & 5 deletions pkg/master/tunneler/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package tunneler

import (
"context"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -111,11 +112,11 @@ type FakeTunneler struct {
SecondsSinceSSHKeySyncValue int64
}

func (t *FakeTunneler) Run(AddressFunc) {}
func (t *FakeTunneler) Stop() {}
func (t *FakeTunneler) Dial(net, addr string) (net.Conn, error) { return nil, nil }
func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue }
func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue }
func (t *FakeTunneler) Run(AddressFunc) {}
func (t *FakeTunneler) Stop() {}
func (t *FakeTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) { return nil, nil }
func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue }
func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue }

// TestIsTunnelSyncHealthy verifies that the 600 second lag test
// is honored.
Expand Down
12 changes: 7 additions & 5 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ssh

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
Expand Down Expand Up @@ -121,10 +122,11 @@ func (s *SSHTunnel) Open() error {
return err
}

func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) {
func (s *SSHTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if s.client == nil {
return nil, errors.New("tunnel is not opened.")
}
// This Dial method does not allow to pass a context unfortunately
return s.client.Dial(network, address)
}

Expand Down Expand Up @@ -294,7 +296,7 @@ func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
type tunnel interface {
Open() error
Close() error
Dial(network, address string) (net.Conn, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}

type sshTunnelEntry struct {
Expand Down Expand Up @@ -361,7 +363,7 @@ func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration
func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
// GET the healthcheck path using the provided tunnel's dial function.
transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: e.Tunnel.Dial,
DialContext: e.Tunnel.Dial,
// TODO(cjcullen): Plumb real TLS options through.
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// We don't reuse the clients, so disable the keep-alive to properly
Expand Down Expand Up @@ -394,7 +396,7 @@ func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
go l.createAndAddTunnel(e.Address)
}

func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
start := time.Now()
id := mathrand.Int63() // So you can match begins/ends in the log.
glog.Infof("[%x: %v] Dialing...", id, addr)
Expand All @@ -405,7 +407,7 @@ func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
return tunnel.Dial(net, addr)
return tunnel.Dial(ctx, net, addr)
}

func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {
Expand Down
5 changes: 3 additions & 2 deletions pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package ssh

import (
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -145,7 +146,7 @@ func TestSSHTunnel(t *testing.T) {
t.FailNow()
}

_, err = tunnel.Dial("tcp", "127.0.0.1:8080")
_, err = tunnel.Dial(context.Background(), "tcp", "127.0.0.1:8080")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -176,7 +177,7 @@ func (*fakeTunnel) Close() error {
return nil
}

func (*fakeTunnel) Dial(network, address string) (net.Conn, error) {
func (*fakeTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package spdy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -118,7 +119,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
}

if proxyURL == nil {
return s.dialWithoutProxy(req.URL)
return s.dialWithoutProxy(req.Context(), req.URL)
}

// ensure we use a canonical host with proxyReq
Expand All @@ -136,7 +137,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
proxyReq.Header.Set("Proxy-Authorization", pa)
}

proxyDialConn, err := s.dialWithoutProxy(proxyURL)
proxyDialConn, err := s.dialWithoutProxy(req.Context(), proxyURL)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -187,14 +188,15 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
}

// dialWithoutProxy dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dialWithoutProxy(url *url.URL) (net.Conn, error) {
func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)

if url.Scheme == "http" {
if s.Dialer == nil {
return net.Dial("tcp", dialAddr)
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
} else {
return s.Dialer.Dial("tcp", dialAddr)
return s.Dialer.DialContext(ctx, "tcp", dialAddr)
}
}

Expand Down
9 changes: 5 additions & 4 deletions staging/src/k8s.io/apimachinery/pkg/util/net/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package net
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
Expand Down Expand Up @@ -90,8 +91,8 @@ func SetOldTransportDefaults(t *http.Transport) *http.Transport {
// ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
if t.Dial == nil {
t.Dial = defaultTransport.Dial
if t.DialContext == nil {
t.DialContext = defaultTransport.DialContext
}
if t.TLSHandshakeTimeout == 0 {
t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
Expand Down Expand Up @@ -119,7 +120,7 @@ type RoundTripperWrapper interface {
WrappedRoundTripper() http.RoundTripper
}

type DialFunc func(net, addr string) (net.Conn, error)
type DialFunc func(ctx context.Context, net, addr string) (net.Conn, error)

func DialerFor(transport http.RoundTripper) (DialFunc, error) {
if transport == nil {
Expand All @@ -128,7 +129,7 @@ func DialerFor(transport http.RoundTripper) (DialFunc, error) {

switch transport := transport.(type) {
case *http.Transport:
return transport.Dial, nil
return transport.DialContext, nil
case RoundTripperWrapper:
return DialerFor(transport.WrappedRoundTripper())
default:
Expand Down
12 changes: 7 additions & 5 deletions staging/src/k8s.io/apimachinery/pkg/util/proxy/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package proxy

import (
"context"
"crypto/tls"
"fmt"
"net"
Expand All @@ -29,7 +30,7 @@ import (
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)

func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)

dialer, err := utilnet.DialerFor(transport)
Expand All @@ -40,9 +41,10 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
switch url.Scheme {
case "http":
if dialer != nil {
return dialer("tcp", dialAddr)
return dialer(ctx, "tcp", dialAddr)
}
return net.Dial("tcp", dialAddr)
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
Expand All @@ -56,7 +58,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr)
netConn, err := dialer(ctx, "tcp", dialAddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -86,7 +88,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
}

} else {
// Dial
// Dial. This Dial method does not allow to pass a context unfortunately
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
Expand Down
16 changes: 9 additions & 7 deletions staging/src/k8s.io/apimachinery/pkg/util/proxy/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package proxy

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand All @@ -42,6 +43,7 @@ func TestDialURL(t *testing.T) {
if err != nil {
t.Fatal(err)
}
var d net.Dialer

testcases := map[string]struct {
TLSConfig *tls.Config
Expand All @@ -68,25 +70,25 @@ func TestDialURL(t *testing.T) {

"insecure, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Dial: net.Dial,
Dial: d.DialContext,
},
"secure, no roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
Dial: net.Dial,
Dial: d.DialContext,
ExpectError: "unknown authority",
},
"secure with roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
Dial: net.Dial,
Dial: d.DialContext,
},
"secure with mismatched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
Dial: net.Dial,
Dial: d.DialContext,
ExpectError: "not bogus.com",
},
"secure with matched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: net.Dial,
Dial: d.DialContext,
},
}

Expand All @@ -102,7 +104,7 @@ func TestDialURL(t *testing.T) {
// Clone() mutates the receiver (!), so also call it on the copy
tlsConfigCopy.Clone()
transport := &http.Transport{
Dial: tc.Dial,
DialContext: tc.Dial,
TLSClientConfig: tlsConfigCopy,
}

Expand All @@ -125,7 +127,7 @@ func TestDialURL(t *testing.T) {
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(u, transport)
conn, err := DialURL(context.Background(), u, transport)

// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error

// dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := DialURL(req.URL, transport)
conn, err := DialURL(req.Context(), req.URL, transport)
if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err)
}
Expand Down
Loading

0 comments on commit 5e8e570

Please sign in to comment.