Skip to content

Commit

Permalink
Split protecting and resolving
Browse files Browse the repository at this point in the history
  • Loading branch information
oxtoacart committed Feb 5, 2016
1 parent 80539d0 commit 735d7a9
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 97 deletions.
181 changes: 88 additions & 93 deletions src/github.com/getlantern/protected/protected.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@ type SocketProtector interface {

type ProtectedConn struct {
net.Conn
mutex sync.Mutex
protector SocketProtector
isClosed bool
socketFd int
addr string
host string
ip [4]byte
port int
mutex sync.Mutex
isClosed bool
socketFd int
ip [4]byte
port int
}

var (
Expand All @@ -55,7 +52,79 @@ func Configure(protector SocketProtector, dnsServer string) {
}
}

// Dial creates a new protected connection
// Resolve resolves the given address using a DNS lookup on a UDP socket
// protected by the currnet Protector.
func Resolve(addr string) (*net.TCPAddr, error) {
host, port, err := SplitHostPort(addr)
if err != nil {
return nil, err
}

// Check if we already have the IP address
IPAddr := net.ParseIP(host)
if IPAddr != nil {
return &net.TCPAddr{IP: IPAddr, Port: port}, nil
}

// Create a datagram socket
socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
if err != nil {
return nil, fmt.Errorf("Error creating socket: %v", err)
}
defer syscall.Close(socketFd)

// Here we protect the underlying socket from the
// VPN connection by passing the file descriptor
// back to Java for exclusion
err = currentProtector.Protect(socketFd)
if err != nil {
return nil, fmt.Errorf("Could not bind socket to system device: %v", err)
}

IPAddr = net.ParseIP(currentDnsServer)
if IPAddr == nil {
return nil, errors.New("invalid IP address")
}

var ip [4]byte
copy(ip[:], IPAddr.To4())
sockAddr := syscall.SockaddrInet4{Addr: ip, Port: dnsPort}

err = syscall.Connect(socketFd, &sockAddr)
if err != nil {
return nil, err
}

fd := uintptr(socketFd)
file := os.NewFile(fd, "")
defer file.Close()

// return a copy of the network connection
// represented by file
fileConn, err := net.FileConn(file)
if err != nil {
log.Errorf("Error returning a copy of the network connection: %v", err)
return nil, err
}

setQueryTimeouts(fileConn)

log.Debugf("performing dns lookup...!!")
result, err := dnsLookup(host, fileConn)
if err != nil {
log.Errorf("Error doing DNS resolution: %v", err)
return nil, err
}
ipAddr, err := result.PickRandomIP()
if err != nil {
log.Errorf("No IP address available: %v", err)
return nil, err
}
return &net.TCPAddr{IP: ipAddr, Port: port}, nil
}

// Dial creates a new protected connection, it assumes that the address has
// already been resolved to an IPv4 address.
// - syscall API calls are used to create and bind to the
// specified system device (this is primarily
// used for Android VpnService routing functionality)
Expand All @@ -66,33 +135,30 @@ func Dial(network, addr string, timeout time.Duration) (*ProtectedConn, error) {
}

conn := &ProtectedConn{
addr: addr,
host: host,
port: port,
protector: currentProtector,
port: port,
}
// do DNS query
IPAddr, err := conn.resolveHostname()
if err != nil {
log.Errorf("Couldn't resolve host %s: %s", conn.addr, err)
IPAddr := net.ParseIP(host)
if IPAddr == nil {
log.Errorf("Couldn't parse IP address %v", host)
return nil, err
}

copy(conn.ip[:], IPAddr.To4())

socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
if err != nil {
log.Errorf("Could not create socket: %s", err)
log.Errorf("Could not create socket: %v", err)
return nil, err
}
conn.socketFd = socketFd

defer conn.cleanup()

// Actually protect the underlying socket here
err = conn.protector.Protect(conn.socketFd)
err = currentProtector.Protect(conn.socketFd)
if err != nil {
return nil, fmt.Errorf("Could not bind socket to system device: %s", err)
return nil, fmt.Errorf("Could not bind socket to system device: %v", err)
}

err = conn.connectSocket()
Expand All @@ -104,7 +170,7 @@ func Dial(network, addr string, timeout time.Duration) (*ProtectedConn, error) {
// finally, convert the socket fd to a net.Conn
err = conn.convert()
if err != nil {
log.Errorf("Error converting protected connection: %s", err)
log.Errorf("Error converting protected connection: %v", err)
return nil, err
}

Expand All @@ -127,10 +193,6 @@ func (conn *ProtectedConn) connectSocket() error {
return err
}

func (conn *ProtectedConn) Addr() (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", conn.addr)
}

// converts the protected connection specified by
// socket fd to a net.Conn
func (conn *ProtectedConn) convert() error {
Expand Down Expand Up @@ -194,84 +256,17 @@ func setQueryTimeouts(c net.Conn) {
c.SetWriteDeadline(now.Add(writeDeadline))
}

// resolveHostname creates a UDP socket and binds it to the device
func (conn *ProtectedConn) resolveHostname() (net.IP, error) {

// Check if we already have the IP address
IPAddr := net.ParseIP(conn.host)
if IPAddr != nil {
return IPAddr, nil
}

// Create a datagram socket
socketFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
if err != nil {
log.Errorf("Error creating socket: %s", err)
return nil, err
}
defer syscall.Close(socketFd)

// Here we protect the underlying socket from the
// VPN connection by passing the file descriptor
// back to Java for exclusion
err = conn.protector.Protect(socketFd)
if err != nil {
return nil, fmt.Errorf("Could not bind socket to system device: %s", err)
}

IPAddr = net.ParseIP(currentDnsServer)
if IPAddr == nil {
return nil, errors.New("invalid IP address")
}

var ip [4]byte
copy(ip[:], IPAddr.To4())
sockAddr := syscall.SockaddrInet4{Addr: ip, Port: dnsPort}

err = syscall.Connect(socketFd, &sockAddr)
if err != nil {
return nil, err
}

fd := uintptr(socketFd)
file := os.NewFile(fd, "")
defer file.Close()

// return a copy of the network connection
// represented by file
fileConn, err := net.FileConn(file)
if err != nil {
log.Errorf("Error returning a copy of the network connection: %v", err)
return nil, err
}

setQueryTimeouts(fileConn)

log.Debugf("performing dns lookup...!!")
result, err := dnsLookup(conn.host, fileConn)
if err != nil {
log.Errorf("Error doing DNS resolution: %s", err)
return nil, err
}
ipAddr, err := result.PickRandomIP()
if err != nil {
log.Errorf("No IP address available: %s", err)
return nil, err
}
return ipAddr, nil
}

// wrapper around net.SplitHostPort that also converts
// uses strconv to convert the port to an int
func SplitHostPort(addr string) (string, int, error) {
host, sPort, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Could not split network address: %s", err)
log.Errorf("Could not split network address: %v", err)
return "", 0, err
}
port, err := strconv.Atoi(sPort)
if err != nil {
log.Errorf("No port number found %s", err)
log.Errorf("No port number found %v", err)
return "", 0, err
}
return host, port, nil
Expand Down
6 changes: 5 additions & 1 deletion src/github.com/getlantern/protected/protected_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ func TestConnect(t *testing.T) {
client := &http.Client{
Transport: &http.Transport{
Dial: func(netw, addr string) (net.Conn, error) {
return Dial(netw, addr, 10*time.Second)
resolved, err := Resolve(addr)
if err != nil {
return nil, err
}
return Dial(netw, resolved.String(), 10*time.Second)
},
ResponseHeaderTimeout: time.Second * 2,
},
Expand Down
34 changes: 31 additions & 3 deletions src/github.com/getlantern/tlsdialer/tlsdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ import (

var (
log = golog.LoggerFor("tlsdialer")

resolve = func(addr string) (*net.TCPAddr, error) {
resolved, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
return resolved, nil
}

dialOverride func(network, addr string, timeout time.Duration) (net.Conn, error)
)

type timeoutError struct{}
Expand All @@ -40,6 +50,17 @@ type ConnWithTimings struct {
VerifiedChains [][]*x509.Certificate
}

// OverrideResolve allows overriding the DNS resolution function
func OverrideResolve(override func(addr string) (*net.TCPAddr, error)) {
resolve = override
}

// OverrideDial allows specifying a function that will be used to dial in lieu
// of a net.Dialer.
func OverrideDial(override func(network, addr string, timeout time.Duration) (net.Conn, error)) {
dialOverride = override
}

// Like crypto/tls.Dial, but with the ability to control whether or not to
// send the ServerName extension in client handshakes through the sendServerName
// flag.
Expand Down Expand Up @@ -94,12 +115,12 @@ func DialForTimings(dialer *net.Dialer, network, addr string, sendServerName boo
var err error
if timeout == 0 {
log.Tracef("Resolving immediately")
result.ResolvedAddr, err = net.ResolveTCPAddr("tcp", addr)
result.ResolvedAddr, err = resolve(addr)
} else {
log.Tracef("Resolving on goroutine")
resolvedCh := make(chan *net.TCPAddr, 10)
go func() {
resolved, err := net.ResolveTCPAddr("tcp", addr)
resolved, err := resolve(addr)
log.Tracef("Resolution resulted in %s : %s", resolved, err)
resolvedCh <- resolved
errCh <- err
Expand All @@ -119,7 +140,14 @@ func DialForTimings(dialer *net.Dialer, network, addr string, sendServerName boo

log.Tracef("Dialing %s %s (%s)", network, addr, result.ResolvedAddr)
start = time.Now()
rawConn, err := dialer.Dial(network, result.ResolvedAddr.String())
resolvedAddr := result.ResolvedAddr.String()
var rawConn net.Conn
if dialOverride != nil {
log.Trace("Dialing with dialOverride")
rawConn, err = dialOverride(network, resolvedAddr, timeout)
} else {
rawConn, err = dialer.Dial(network, resolvedAddr)
}
if err != nil {
return result, err
}
Expand Down

0 comments on commit 735d7a9

Please sign in to comment.