Skip to content

Commit

Permalink
add chain.DialContext
Browse files Browse the repository at this point in the history
  • Loading branch information
ginuerzh committed Feb 8, 2020
1 parent 425099a commit abe4043
Show file tree
Hide file tree
Showing 17 changed files with 414 additions and 375 deletions.
50 changes: 36 additions & 14 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gost

import (
"context"
"errors"
"net"
"time"
Expand Down Expand Up @@ -100,9 +101,14 @@ func (c *Chain) IsEmpty() bool {
return c == nil || len(c.nodeGroups) == 0
}

// Dial connects to the target address addr through the chain.
// If the chain is empty, it will use the net.Dial directly.
func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) {
// Dial connects to the target TCP address addr through the chain.
// Deprecated: use DialContext instead.
func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(context.Background(), "tcp", address, opts...)
}

// DialContext connects to the address on the named network using the provided context.
func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
opt(options)
Expand All @@ -117,41 +123,55 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error
}

for i := 0; i < retries; i++ {
conn, err = c.dialWithOptions(addr, options)
conn, err = c.dialWithOptions(ctx, network, address, options)
if err == nil {
break
}
}
return
}

func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) {
func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
route, err := c.selectRouteFor(addr)
route, err := c.selectRouteFor(address)
if err != nil {
return nil, err
}

ipAddr := c.resolve(addr, options.Resolver, options.Hosts)
ipAddr := address
if address != "" {
ipAddr = c.resolve(address, options.Resolver, options.Hosts)
}

timeout := options.Timeout
if timeout <= 0 {
timeout = DialTimeout
}

if route.IsEmpty() {
return net.DialTimeout("tcp", ipAddr, timeout)
switch network {
case "udp", "udp4", "udp6":
if address == "" {
return net.ListenUDP(network, nil)
}
default:
}
d := &net.Dialer{
Timeout: timeout,
// LocalAddr: laddr, // TODO: optional local address
}
return d.DialContext(ctx, network, ipAddr)
}

conn, err := route.getConn()
conn, err := route.getConn(ctx)
if err != nil {
return nil, err
}

cOpts := append([]ConnectOption{AddrConnectOption(addr)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.Connect(conn, ipAddr, cOpts...)
cOpts := append([]ConnectOption{AddrConnectOption(address)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.ConnectContext(ctx, conn, network, ipAddr, cOpts...)
if err != nil {
conn.Close()
return nil, err
Expand Down Expand Up @@ -187,6 +207,8 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
opt(options)
}

ctx := context.Background()

retries := 1
if c != nil && c.Retries > 0 {
retries = c.Retries
Expand All @@ -201,7 +223,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
if err != nil {
continue
}
conn, err = route.getConn()
conn, err = route.getConn(ctx)
if err == nil {
break
}
Expand All @@ -210,7 +232,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
}

// getConn obtains a connection to the last node of the chain.
func (c *Chain) getConn() (conn net.Conn, err error) {
func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
if c.IsEmpty() {
err = ErrEmptyChain
return
Expand All @@ -234,7 +256,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
preNode := node
for _, node := range nodes[1:] {
var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr, preNode.ConnectOptions...)
cc, err = preNode.Client.ConnectContext(ctx, cn, "tcp", node.Addr, preNode.ConnectOptions...)
if err != nil {
cn.Close()
node.MarkDead()
Expand Down
51 changes: 33 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gost

import (
"context"
"crypto/tls"
"net"
"net/url"
Expand All @@ -14,23 +15,8 @@ import (
// Connector is responsible for connecting to the destination address through this proxy.
// Transporter performs a handshake with this proxy.
type Client struct {
Connector Connector
Transporter Transporter
}

// Dial connects to the target address.
func (c *Client) Dial(addr string, options ...DialOption) (net.Conn, error) {
return c.Transporter.Dial(addr, options...)
}

// Handshake performs a handshake with the proxy over connection conn.
func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return c.Transporter.Handshake(conn, options...)
}

// Connect connects to the address addr via the proxy over connection conn.
func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return c.Connector.Connect(conn, addr, options...)
Connector
Transporter
}

// DefaultClient is a standard HTTP proxy client.
Expand All @@ -53,7 +39,36 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) {

// Connector is responsible for connecting to the destination address.
type Connector interface {
Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error)
// Deprecated: use ConnectContext instead.
Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error)
ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error)
}

type autoConnector struct {
User *url.Userinfo
}

// AutoConnector is a Connector.
func AutoConnector(user *url.Userinfo) Connector {
return &autoConnector{
User: user,
}
}

func (c *autoConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}

func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
var cnr Connector
switch network {
case "tcp", "tcp4", "tcp6":
cnr = &httpConnector{User: c.User}
default:
cnr = &socks5UDPTunConnector{User: c.User}
}

return cnr.ConnectContext(ctx, conn, network, address, options...)
}

// Transporter is responsible for handshaking with the proxy server.
Expand Down
5 changes: 2 additions & 3 deletions cmd/gost/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "sni":
connector = gost.SNIConnector(node.Get("host"))
case "http":
fallthrough
default:
node.Protocol = "http" // default protocol is HTTP
connector = gost.HTTPConnector(node.User)
default:
connector = gost.AutoConnector(node.User)
}

timeout := node.GetInt("timeout")
Expand Down
43 changes: 9 additions & 34 deletions forward.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gost

import (
"context"
"errors"
"net"
"strings"
Expand All @@ -22,7 +23,11 @@ func ForwardConnector() Connector {
return &forwardConnector{}
}

func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}

func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}

Expand Down Expand Up @@ -186,42 +191,12 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
return
}

raddr, err := net.ResolveUDPAddr("udp", node.Addr)
cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}

var cc net.Conn
if h.options.Chain.IsEmpty() {
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}

cc = &udpTunnelConn{Conn: cc, raddr: raddr}
}

defer cc.Close()
node.ResetDead()

Expand Down Expand Up @@ -726,11 +701,11 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "socks5" {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(l.chain, l.addr)
cc, err = getSocks5UDPTunnel(l.chain, l.addr)
if err != nil {
log.Logf("[rudp] %s : %s", l.Addr(), err)
} else {
conn = &udpTunnelConn{Conn: cc}
conn = cc.(net.PacketConn)
}
} else {
var uc *net.UDPConn
Expand Down
2 changes: 1 addition & 1 deletion gost.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
)

// Version is the gost version.
const Version = "2.10.0"
const Version = "2.10.1"

// Debug is a flag that enables the debug log.
var Debug bool
Expand Down
16 changes: 13 additions & 3 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"net"
Expand All @@ -27,7 +28,16 @@ func HTTPConnector(user *url.Userinfo) Connector {
return &httpConnector{User: user}
}

func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}

func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}

opts := &ConnectOptions{}
for _, option := range options {
option(opts)
Expand All @@ -47,8 +57,8 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp

req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Host: addr,
URL: &url.URL{Host: address},
Host: address,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Expand Down
16 changes: 13 additions & 3 deletions http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
Expand Down Expand Up @@ -33,7 +34,16 @@ func HTTP2Connector(user *url.Userinfo) Connector {
return &http2Connector{User: user}
}

func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}

func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}

opts := &ConnectOptions{}
for _, option := range options {
option(opts)
Expand All @@ -57,7 +67,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
ProtoMajor: 2,
ProtoMinor: 0,
Body: pr,
Host: addr,
Host: address,
ContentLength: -1,
}
req.Header.Set("User-Agent", ua)
Expand Down Expand Up @@ -97,7 +107,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
closed: make(chan struct{}),
}

hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address)
hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr)

return hc, nil
Expand Down
Loading

0 comments on commit abe4043

Please sign in to comment.