diff --git a/chain.go b/chain.go index 5b0cd459..4a511cb9 100644 --- a/chain.go +++ b/chain.go @@ -1,6 +1,7 @@ package gost import ( + "context" "errors" "net" "time" @@ -100,9 +101,14 @@ func (c *Chain) IsEmpty() bool { return c == nil || len(c.nodeGroups) == 0 } -// Dial connects to the target address addr through the chain. -// If the chain is empty, it will use the net.Dial directly. -func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) { +// Dial connects to the target TCP address addr through the chain. +// Deprecated: use DialContext instead. +func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) { + return c.DialContext(context.Background(), "tcp", address, opts...) +} + +// DialContext connects to the address on the named network using the provided context. +func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { options := &ChainOptions{} for _, opt := range opts { opt(options) @@ -117,7 +123,7 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error } for i := 0; i < retries; i++ { - conn, err = c.dialWithOptions(addr, options) + conn, err = c.dialWithOptions(ctx, network, address, options) if err == nil { break } @@ -125,16 +131,19 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error return } -func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) { +func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { if options == nil { options = &ChainOptions{} } - route, err := c.selectRouteFor(addr) + route, err := c.selectRouteFor(address) if err != nil { return nil, err } - ipAddr := c.resolve(addr, options.Resolver, options.Hosts) + ipAddr := address + if address != "" { + ipAddr = c.resolve(address, options.Resolver, options.Hosts) + } timeout := options.Timeout if timeout <= 0 { @@ -142,16 +151,27 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e } if route.IsEmpty() { - return net.DialTimeout("tcp", ipAddr, timeout) + switch network { + case "udp", "udp4", "udp6": + if address == "" { + return net.ListenUDP(network, nil) + } + default: + } + d := &net.Dialer{ + Timeout: timeout, + // LocalAddr: laddr, // TODO: optional local address + } + return d.DialContext(ctx, network, ipAddr) } - conn, err := route.getConn() + conn, err := route.getConn(ctx) if err != nil { return nil, err } - cOpts := append([]ConnectOption{AddrConnectOption(addr)}, route.LastNode().ConnectOptions...) - cc, err := route.LastNode().Client.Connect(conn, ipAddr, cOpts...) + cOpts := append([]ConnectOption{AddrConnectOption(address)}, route.LastNode().ConnectOptions...) + cc, err := route.LastNode().Client.ConnectContext(ctx, conn, network, ipAddr, cOpts...) if err != nil { conn.Close() return nil, err @@ -187,6 +207,8 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { opt(options) } + ctx := context.Background() + retries := 1 if c != nil && c.Retries > 0 { retries = c.Retries @@ -201,7 +223,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { if err != nil { continue } - conn, err = route.getConn() + conn, err = route.getConn(ctx) if err == nil { break } @@ -210,7 +232,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { } // getConn obtains a connection to the last node of the chain. -func (c *Chain) getConn() (conn net.Conn, err error) { +func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) { if c.IsEmpty() { err = ErrEmptyChain return @@ -234,7 +256,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) { preNode := node for _, node := range nodes[1:] { var cc net.Conn - cc, err = preNode.Client.Connect(cn, node.Addr, preNode.ConnectOptions...) + cc, err = preNode.Client.ConnectContext(ctx, cn, "tcp", node.Addr, preNode.ConnectOptions...) if err != nil { cn.Close() node.MarkDead() diff --git a/client.go b/client.go index a5c03e26..ad2e3668 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package gost import ( + "context" "crypto/tls" "net" "net/url" @@ -14,23 +15,8 @@ import ( // Connector is responsible for connecting to the destination address through this proxy. // Transporter performs a handshake with this proxy. type Client struct { - Connector Connector - Transporter Transporter -} - -// Dial connects to the target address. -func (c *Client) Dial(addr string, options ...DialOption) (net.Conn, error) { - return c.Transporter.Dial(addr, options...) -} - -// Handshake performs a handshake with the proxy over connection conn. -func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return c.Transporter.Handshake(conn, options...) -} - -// Connect connects to the address addr via the proxy over connection conn. -func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { - return c.Connector.Connect(conn, addr, options...) + Connector + Transporter } // DefaultClient is a standard HTTP proxy client. @@ -53,7 +39,36 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) { // Connector is responsible for connecting to the destination address. type Connector interface { - Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) + // Deprecated: use ConnectContext instead. + Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) + ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) +} + +type autoConnector struct { + User *url.Userinfo +} + +// AutoConnector is a Connector. +func AutoConnector(user *url.Userinfo) Connector { + return &autoConnector{ + User: user, + } +} + +func (c *autoConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + var cnr Connector + switch network { + case "tcp", "tcp4", "tcp6": + cnr = &httpConnector{User: c.User} + default: + cnr = &socks5UDPTunConnector{User: c.User} + } + + return cnr.ConnectContext(ctx, conn, network, address, options...) } // Transporter is responsible for handshaking with the proxy server. diff --git a/cmd/gost/route.go b/cmd/gost/route.go index dc007662..d6a4f18d 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -227,10 +227,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { case "sni": connector = gost.SNIConnector(node.Get("host")) case "http": - fallthrough - default: - node.Protocol = "http" // default protocol is HTTP connector = gost.HTTPConnector(node.User) + default: + connector = gost.AutoConnector(node.User) } timeout := node.GetInt("timeout") diff --git a/forward.go b/forward.go index d7e93f81..f11ae342 100644 --- a/forward.go +++ b/forward.go @@ -1,6 +1,7 @@ package gost import ( + "context" "errors" "net" "strings" @@ -22,7 +23,11 @@ func ForwardConnector() Connector { return &forwardConnector{} } -func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return conn, nil +} + +func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { return conn, nil } @@ -186,42 +191,12 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) { return } - raddr, err := net.ResolveUDPAddr("udp", node.Addr) + cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr) if err != nil { node.MarkDead() log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) return } - - var cc net.Conn - if h.options.Chain.IsEmpty() { - cc, err = net.DialUDP("udp", nil, raddr) - if err != nil { - node.MarkDead() - log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) - return - } - } else if h.options.Chain.LastNode().Protocol == "ssu" { - cc, err = h.options.Chain.Dial(node.Addr, - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) - if err != nil { - node.MarkDead() - log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) - return - } - } else { - var err error - cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - if err != nil { - log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) - return - } - - cc = &udpTunnelConn{Conn: cc, raddr: raddr} - } - defer cc.Close() node.ResetDead() @@ -726,11 +701,11 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { lastNode := l.chain.LastNode() if lastNode.Protocol == "socks5" { var cc net.Conn - cc, err = getSOCKS5UDPTunnel(l.chain, l.addr) + cc, err = getSocks5UDPTunnel(l.chain, l.addr) if err != nil { log.Logf("[rudp] %s : %s", l.Addr(), err) } else { - conn = &udpTunnelConn{Conn: cc} + conn = cc.(net.PacketConn) } } else { var uc *net.UDPConn diff --git a/gost.go b/gost.go index 245903ab..cf75028c 100644 --- a/gost.go +++ b/gost.go @@ -20,7 +20,7 @@ import ( ) // Version is the gost version. -const Version = "2.10.0" +const Version = "2.10.1" // Debug is a flag that enables the debug log. var Debug bool diff --git a/http.go b/http.go index 979d1e7a..2027dc84 100644 --- a/http.go +++ b/http.go @@ -3,6 +3,7 @@ package gost import ( "bufio" "bytes" + "context" "encoding/base64" "fmt" "net" @@ -27,7 +28,16 @@ func HTTPConnector(user *url.Userinfo) Connector { return &httpConnector{User: user} } -func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -47,8 +57,8 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp req := &http.Request{ Method: http.MethodConnect, - URL: &url.URL{Host: addr}, - Host: addr, + URL: &url.URL{Host: address}, + Host: address, ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), diff --git a/http2.go b/http2.go index 49174dda..ce4435f0 100644 --- a/http2.go +++ b/http2.go @@ -3,6 +3,7 @@ package gost import ( "bufio" "bytes" + "context" "crypto/tls" "encoding/base64" "errors" @@ -33,7 +34,16 @@ func HTTP2Connector(user *url.Userinfo) Connector { return &http2Connector{User: user} } -func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -57,7 +67,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO ProtoMajor: 2, ProtoMinor: 0, Body: pr, - Host: addr, + Host: address, ContentLength: -1, } req.Header.Set("User-Agent", ua) @@ -97,7 +107,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO closed: make(chan struct{}), } - hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) + hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address) hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr) return hc, nil diff --git a/redirect.go b/redirect.go index 5c83360e..f9597705 100644 --- a/redirect.go +++ b/redirect.go @@ -3,6 +3,7 @@ package gost import ( + "context" "errors" "fmt" "net" @@ -132,32 +133,14 @@ func (h *udpRedirectHandler) Handle(conn net.Conn) { return } - var cc net.Conn - var err error - if h.options.Chain.IsEmpty() { - cc, err = net.DialUDP("udp", nil, raddr) - if err != nil { - log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) - return - } - } else if h.options.Chain.LastNode().Protocol == "ssu" { - cc, err = h.options.Chain.Dial(raddr.String(), - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) - if err != nil { - log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) - return - } - } else { - var err error - cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - if err != nil { - log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) - return - } - - cc = &udpTunnelConn{Conn: cc, raddr: raddr} + cc, err := h.options.Chain.DialContext(context.Background(), + "udp", raddr.String(), + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) + return } defer cc.Close() diff --git a/resolver.go b/resolver.go index aabd9d7c..99529a0c 100644 --- a/resolver.go +++ b/resolver.go @@ -606,31 +606,12 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - if ex.options.chain.IsEmpty() { - d := &net.Dialer{ - Timeout: ex.options.timeout, - } - return d.DialContext(ctx, network, address) - } - - if ex.options.chain.LastNode().Protocol == "ssu" { - return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) - } - - raddr, err := net.ResolveUDPAddr(network, address) - if err != nil { - return - } - - cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil) - conn = &udpTunnelConn{Conn: cc, raddr: raddr} - return -} - func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.dial(ctx, "udp", ex.addr) + c, err := ex.options.chain.DialContext(ctx, + "udp", ex.addr, + TimeoutChainOption(ex.options.timeout), + ) if err != nil { return nil, err } @@ -674,19 +655,12 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { } } -func (ex *dnsTCPExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - if ex.options.chain.IsEmpty() { - d := &net.Dialer{ - Timeout: ex.options.timeout, - } - return d.DialContext(ctx, network, address) - } - return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) -} - func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { t := time.Now() - c, err := ex.dial(ctx, "tcp", ex.addr) + c, err := ex.options.chain.DialContext(ctx, + "tcp", ex.addr, + TimeoutChainOption(ex.options.timeout), + ) if err != nil { return nil, err } @@ -738,14 +712,10 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption } func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - if ex.options.chain.IsEmpty() { - d := &net.Dialer{ - Timeout: ex.options.timeout, - } - conn, err = d.DialContext(ctx, network, address) - } else { - conn, err = ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) - } + conn, err = ex.options.chain.DialContext(ctx, + network, address, + TimeoutChainOption(ex.options.timeout), + ) if err != nil { return } @@ -812,14 +782,11 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp return ex } -func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) { - if ex.options.chain.IsEmpty() { - d := &net.Dialer{ - Timeout: ex.options.timeout, - } - return d.DialContext(ctx, network, address) - } - return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout)) +func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return ex.options.chain.DialContext(ctx, + network, address, + TimeoutChainOption(ex.options.timeout), + ) } func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { diff --git a/snapcraft.yaml b/snapcraft.yaml index 0f8a0ba3..54556e9a 100644 --- a/snapcraft.yaml +++ b/snapcraft.yaml @@ -1,6 +1,6 @@ name: gost type: app -version: '2.10.0' +version: '2.10.1' title: GO Simple Tunnel summary: A simple security tunnel written in golang description: | diff --git a/sni.go b/sni.go index 5a25cb3c..3cfeb26b 100644 --- a/sni.go +++ b/sni.go @@ -5,6 +5,7 @@ package gost import ( "bufio" "bytes" + "context" "encoding/base64" "encoding/binary" "errors" @@ -29,8 +30,17 @@ func SNIConnector(host string) Connector { return &sniConnector{host: host} } -func (c *sniConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { - return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil +func (c *sniConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *sniConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + return &sniClientConn{addr: address, host: c.host, Conn: conn}, nil } type sniHandler struct { diff --git a/socks.go b/socks.go index 3c765fae..acd3a39e 100644 --- a/socks.go +++ b/socks.go @@ -2,6 +2,7 @@ package gost import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -35,6 +36,10 @@ const ( CmdUDPTun uint8 = 0xF3 ) +var ( + _ net.PacketConn = (*socks5UDPTunnelConn)(nil) +) + type clientSelector struct { methods []uint8 User *url.Userinfo @@ -201,7 +206,17 @@ func SOCKS5Connector(user *url.Userinfo) Connector { return &socks5Connector{User: user} } -func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks5Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks5Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + cnr := &socks5UDPTunConnector{User: c.User} + return cnr.ConnectContext(ctx, conn, network, address, options...) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -229,7 +244,7 @@ func (c *socks5Connector) Connect(conn net.Conn, addr string, options ...Connect } conn = cc - host, port, err := net.SplitHostPort(addr) + host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } @@ -273,7 +288,16 @@ func SOCKS5BindConnector(user *url.Userinfo) Connector { return &socks5BindConnector{User: user} } -func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks5BindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks5BindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -301,7 +325,7 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con } conn = cc - laddr, err := net.ResolveTCPAddr("tcp", addr) + laddr, err := net.ResolveTCPAddr("tcp", address) if err != nil { log.Log(err) return nil, err @@ -331,8 +355,8 @@ func (c *socks5BindConnector) Connect(conn net.Conn, addr string, options ...Con } if reply.Rep != gosocks5.Succeeded { - log.Logf("[socks5] bind on %s failure", addr) - return nil, fmt.Errorf("SOCKS5 bind on %s failure", addr) + log.Logf("[socks5] bind on %s failure", address) + return nil, fmt.Errorf("SOCKS5 bind on %s failure", address) } baddr, err := net.ResolveTCPAddr("tcp", reply.Addr.String()) if err != nil { @@ -350,8 +374,17 @@ func Socks5MuxBindConnector() Connector { return &socks5MuxBindConnector{} } +func (c *socks5MuxBindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + // NOTE: the conn must be *muxBindClientConn. -func (c *socks5MuxBindConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks5MuxBindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + accepter, ok := conn.(Accepter) if !ok { return nil, errors.New("wrong connection type") @@ -513,7 +546,16 @@ func SOCKS5UDPConnector(user *url.Userinfo) Connector { return &socks5UDPConnector{User: user} } -func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks5UDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *socks5UDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -541,7 +583,7 @@ func (c *socks5UDPConnector) Connect(conn net.Conn, addr string, options ...Conn } conn = cc - taddr, err := net.ResolveUDPAddr("udp", addr) + taddr, err := net.ResolveUDPAddr("udp", address) if err != nil { return nil, err } @@ -596,71 +638,40 @@ func SOCKS5UDPTunConnector(user *url.Userinfo) Connector { return &socks5UDPTunConnector{User: user} } -func (c *socks5UDPTunConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks5UDPTunConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *socks5UDPTunConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) } + user := opts.User + if user == nil { + user = c.User + } + timeout := opts.Timeout if timeout <= 0 { timeout = ConnectTimeout } - conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - user := opts.User - if user == nil { - user = c.User - } - cc, err := socks5Handshake(conn, + taddr, _ := net.ResolveUDPAddr("udp", address) + return newSocks5UDPTunnelConn(conn, + nil, taddr, selectorSocks5HandshakeOption(opts.Selector), userSocks5HandshakeOption(user), noTLSSocks5HandshakeOption(opts.NoTLS), ) - if err != nil { - return nil, err - } - conn = cc - - taddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - - req := gosocks5.NewRequest(CmdUDPTun, &gosocks5.Addr{ - Type: gosocks5.AddrIPv4, - }) - - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - log.Log("[socks5] udp\n", req) - } - - reply, err := gosocks5.ReadReply(conn) - if err != nil { - return nil, err - } - - if Debug { - log.Log("[socks5] udp\n", reply) - } - - if reply.Rep != gosocks5.Succeeded { - log.Logf("[socks5] udp relay failure") - return nil, fmt.Errorf("SOCKS5 udp relay failure") - } - baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) - if err != nil { - return nil, err - } - log.Logf("[socks5] udp-tun associate on %s OK", baddr) - - return &udpTunnelConn{Conn: conn, raddr: taddr}, nil } type socks4Connector struct{} @@ -670,7 +681,16 @@ func SOCKS4Connector() Connector { return &socks4Connector{} } -func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks4Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks4Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -684,7 +704,7 @@ func (c *socks4Connector) Connect(conn net.Conn, addr string, options ...Connect conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - taddr, err := net.ResolveTCPAddr("tcp4", addr) + taddr, err := net.ResolveTCPAddr("tcp4", address) if err != nil { return nil, err } @@ -730,7 +750,16 @@ func SOCKS4AConnector() Connector { return &socks4aConnector{} } -func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *socks4aConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks4aConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -744,7 +773,7 @@ func (c *socks4aConnector) Connect(conn net.Conn, addr string, options ...Connec conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) - host, port, err := net.SplitHostPort(addr) + host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } @@ -1601,6 +1630,7 @@ func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) { } } +// TODO: support ipv6 and domain func toSocksAddr(addr net.Addr) *gosocks5.Addr { host := "0.0.0.0" port := 0 @@ -1795,52 +1825,6 @@ func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) { log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) } -func getSOCKS5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { - conn, err := chain.Conn() - if err != nil { - return nil, err - } - - conn.SetDeadline(time.Now().Add(HandshakeTimeout)) - defer conn.SetDeadline(time.Time{}) - - node := chain.LastNode() - cc, err := socks5Handshake(conn, - userSocks5HandshakeOption(node.User), - noTLSSocks5HandshakeOption(node.GetBool("notls")), - ) - if err != nil { - conn.Close() - return nil, err - } - conn = cc - - req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(addr)) - if err := req.Write(conn); err != nil { - conn.Close() - return nil, err - } - if Debug { - log.Log("[socks5]", req) - } - - reply, err := gosocks5.ReadReply(conn) - if err != nil { - conn.Close() - return nil, err - } - - if Debug { - log.Log("[socks5]", reply) - } - - if reply.Rep != gosocks5.Succeeded { - conn.Close() - return nil, errors.New("UDP tunnel failure") - } - return conn, nil -} - type socks5HandshakeOptions struct { selector gosocks5.Selector user *url.Userinfo @@ -1896,21 +1880,74 @@ func socks5Handshake(conn net.Conn, opts ...socks5HandshakeOption) (net.Conn, er return cc, nil } -type udpTunnelConn struct { - raddr net.Addr +func getSocks5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { + c, err := chain.Conn() + if err != nil { + return nil, err + } + + node := chain.LastNode() + conn, err := newSocks5UDPTunnelConn(c, + addr, nil, + userSocks5HandshakeOption(node.User), + noTLSSocks5HandshakeOption(node.GetBool("notls")), + ) + if err != nil { + c.Close() + } + return conn, nil +} + +type socks5UDPTunnelConn struct { net.Conn + taddr net.Addr } -func (c *udpTunnelConn) Read(b []byte) (n int, err error) { - dgram, err := gosocks5.ReadUDPDatagram(c.Conn) +func newSocks5UDPTunnelConn(conn net.Conn, raddr, taddr net.Addr, opts ...socks5HandshakeOption) (net.Conn, error) { + cc, err := socks5Handshake(conn, opts...) if err != nil { - return + return nil, err } - n = copy(b, dgram.Data) + + req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(raddr)) + if err := req.Write(cc); err != nil { + return nil, err + } + if Debug { + log.Log("[socks5] udp-tun", req) + } + + reply, err := gosocks5.ReadReply(cc) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] udp-tun", reply) + } + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("socks5 UDP tunnel failure") + } + + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] udp-tun associate on %s OK", baddr) + + return &socks5UDPTunnelConn{ + Conn: cc, + taddr: taddr, + }, nil +} + +func (c *socks5UDPTunnelConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) return } -func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (c *socks5UDPTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { dgram, err := gosocks5.ReadUDPDatagram(c.Conn) if err != nil { return @@ -1920,15 +1957,11 @@ func (c *udpTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { return } -func (c *udpTunnelConn) Write(b []byte) (n int, err error) { - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(c.raddr)), b) - if err = dgram.Write(c.Conn); err != nil { - return - } - return len(b), nil +func (c *socks5UDPTunnelConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) } -func (c *udpTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { +func (c *socks5UDPTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) if err = dgram.Write(c.Conn); err != nil { return diff --git a/ss.go b/ss.go index 58e3f24e..17bace47 100644 --- a/ss.go +++ b/ss.go @@ -2,6 +2,7 @@ package gost import ( "bytes" + "context" "encoding/binary" "fmt" "io" @@ -15,6 +16,15 @@ import ( ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" ) +const ( + maxSocksAddrLen = 259 +) + +var ( + _ net.Conn = (*shadowConn)(nil) + _ net.PacketConn = (*shadowUDPPacketConn)(nil) +) + type shadowConnector struct { cipher core.Cipher } @@ -27,7 +37,16 @@ func ShadowConnector(info *url.Userinfo) Connector { } } -func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *shadowConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *shadowConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -38,7 +57,7 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect timeout = ConnectTimeout } - socksAddr, err := gosocks5.NewAddr(addr) + socksAddr, err := gosocks5.NewAddr(address) if err != nil { return nil, err } @@ -183,7 +202,16 @@ func ShadowUDPConnector(info *url.Userinfo) Connector { } } -func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *shadowUDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *shadowUDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -197,13 +225,13 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn conn.SetDeadline(time.Now().Add(timeout)) defer conn.SetDeadline(time.Time{}) + taddr, _ := net.ResolveUDPAddr(network, address) + if taddr == nil { + taddr = &net.UDPAddr{} + } + pc, ok := conn.(net.PacketConn) if ok { - rawaddr, err := ss.RawAddr(addr) - if err != nil { - return nil, err - } - if c.cipher != nil { pc = c.cipher.PacketConn(pc) } @@ -211,22 +239,17 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn return &shadowUDPPacketConn{ PacketConn: pc, raddr: conn.RemoteAddr(), - header: rawaddr, + taddr: taddr, }, nil } - taddr, err := gosocks5.NewAddr(addr) - if err != nil { - return nil, err - } - if c.cipher != nil { conn = c.cipher.StreamConn(conn) } - return &shadowUDPStreamConn{ - Conn: conn, - addr: taddr, + return &socks5UDPTunnelConn{ + Conn: conn, + taddr: taddr, }, nil } @@ -258,23 +281,13 @@ func (h *shadowUDPHandler) Init(options ...HandlerOption) { func (h *shadowUDPHandler) Handle(conn net.Conn) { defer conn.Close() - var err error var cc net.PacketConn - if h.options.Chain.IsEmpty() { - cc, err = net.ListenUDP("udp", nil) - if err != nil { - log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) - return - } - } else { - var c net.Conn - c, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - if err != nil { - log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err) - return - } - cc = &udpTunnelConn{Conn: c} + c, err := h.options.Chain.DialContext(context.Background(), "udp", "") + if err != nil { + log.Logf("[ssu] %s: %s", conn.LocalAddr(), err) + return } + cc = c.(net.PacketConn) defer cc.Close() pc, ok := conn.(net.PacketConn) @@ -466,24 +479,11 @@ func (c *shadowConn) Write(b []byte) (n int, err error) { type shadowUDPPacketConn struct { net.PacketConn - raddr net.Addr - header []byte + raddr net.Addr + taddr net.Addr } -func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) { - n = len(b) // force byte length consistent - buf := bytes.Buffer{} - if _, err = buf.Write(c.header); err != nil { - return - } - if _, err = buf.Write(b); err != nil { - return - } - _, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr) - return -} - -func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) { +func (c *shadowUDPPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { buf := mPool.Get().([]byte) defer mPool.Put(buf) @@ -501,45 +501,44 @@ func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) { return } n = copy(b, dgram.Data) + addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + return -} -func (c *shadowUDPPacketConn) RemoteAddr() net.Addr { - return c.raddr } -type shadowUDPStreamConn struct { - net.Conn - addr *gosocks5.Addr +func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return } -func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) { - dgram, err := gosocks5.ReadUDPDatagram(c.Conn) +func (c *shadowUDPPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + sa, err := gosocks5.NewAddr(addr.String()) + if err != nil { + return + } + var rawaddr [maxSocksAddrLen]byte + nn, err := sa.Encode(rawaddr[:]) if err != nil { return } - n = copy(b, dgram.Data) - return -} -func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() + buf := mPool.Get().([]byte) + defer mPool.Put(buf) + + copy(buf, rawaddr[:nn]) + n = copy(buf[nn:], b) + _, err = c.PacketConn.WriteTo(buf[:n+nn], c.raddr) return } -func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) { - n = len(b) // force byte length consistent - dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b) - buf := bytes.Buffer{} - dgram.Write(&buf) - _, err = c.Conn.Write(buf.Bytes()) - return +func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) } -func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Write(b) +func (c *shadowUDPPacketConn) RemoteAddr() net.Addr { + return c.raddr } type shadowCipher struct { diff --git a/ss_test.go b/ss_test.go index c57c6fd3..62484d04 100644 --- a/ss_test.go +++ b/ss_test.go @@ -138,7 +138,7 @@ var ssProxyTests = []struct { serverCipher *url.Userinfo pass bool }{ - {nil, nil, false}, + {nil, nil, true}, {&url.Userinfo{}, &url.Userinfo{}, true}, {url.User("abc"), url.User("abc"), true}, {url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true}, diff --git a/ssh.go b/ssh.go index d434c310..9fd78770 100644 --- a/ssh.go +++ b/ssh.go @@ -39,6 +39,15 @@ func SSHDirectForwardConnector() Connector { } func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", raddr, options...) +} + +func (c *sshDirectForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, raddr string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + opts := &ConnectOptions{} for _, option := range options { option(opts) @@ -73,7 +82,16 @@ func SSHRemoteForwardConnector() Connector { return &sshRemoteForwardConnector{} } -func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) { +func (c *sshRemoteForwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *sshRemoteForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. if !ok { return nil, errors.New("ssh: wrong connection type") @@ -87,10 +105,10 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options if cc.session == nil || cc.session.client == nil { return } - if strings.HasPrefix(addr, ":") { - addr = "0.0.0.0" + addr + if strings.HasPrefix(address, ":") { + address = "0.0.0.0" + address } - ln, err := cc.session.client.Listen("tcp", addr) + ln, err := cc.session.client.Listen("tcp", address) if err != nil { return } @@ -99,7 +117,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options for { rc, err := ln.Accept() if err != nil { - log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), addr, err) + log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), address, err) return } // log.Log("[ssh-rtcp] accept", rc.LocalAddr(), rc.RemoteAddr()) @@ -107,7 +125,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options case cc.session.connChan <- rc: default: rc.Close() - log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr) + log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), address) } } }() diff --git a/tuntap.go b/tuntap.go index d490a583..8f47b418 100644 --- a/tuntap.go +++ b/tuntap.go @@ -1,6 +1,7 @@ package gost import ( + "context" "errors" "fmt" "io" @@ -167,9 +168,11 @@ func (h *tunHandler) Handle(conn net.Conn) { var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - var cc net.Conn - cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - pc = &udpTunnelConn{Conn: cc, raddr: raddr} + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + if err != nil { + return err + } + pc = cc.(net.PacketConn) } else { if h.options.TCPMode { if raddr != nil { @@ -549,9 +552,11 @@ func (h *tapHandler) Handle(conn net.Conn) { var pc net.PacketConn // fake tcp mode will be ignored when the client specifies a chain. if raddr != nil && !h.options.Chain.IsEmpty() { - var cc net.Conn - cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil) - pc = &udpTunnelConn{Conn: cc, raddr: raddr} + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + if err != nil { + return err + } + pc = cc.(net.PacketConn) } else { if h.options.TCPMode { if raddr != nil { diff --git a/udp.go b/udp.go index 34088ea4..5dccedfc 100644 --- a/udp.go +++ b/udp.go @@ -19,19 +19,17 @@ func UDPTransporter() Transporter { } func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - raddr, err := net.ResolveUDPAddr("udp", addr) + taddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - conn, err := net.ListenUDP("udp", nil) + conn, err := net.DialUDP("udp", nil, taddr) if err != nil { return nil, err } - return &udpClientConn{ UDPConn: conn, - raddr: raddr, }, nil } @@ -340,19 +338,14 @@ func (c *udpServerConn) SetWriteDeadline(t time.Time) error { type udpClientConn struct { *net.UDPConn - raddr net.Addr } -func (c *udpClientConn) Write(b []byte) (int, error) { - if c.raddr != nil { - return c.WriteTo(b, c.raddr) - } +func (c *udpClientConn) WriteTo(b []byte, addr net.Addr) (int, error) { return c.UDPConn.Write(b) } -func (c *udpClientConn) RemoteAddr() net.Addr { - if c.raddr != nil { - return c.raddr - } - return c.UDPConn.RemoteAddr() +func (c *udpClientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.RemoteAddr() + return }