Skip to content

Commit

Permalink
love io.WriterTo txthinking#1
Browse files Browse the repository at this point in the history
  • Loading branch information
txthinking committed Mar 27, 2020
1 parent 9f204fc commit 6926f2b
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 105 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,32 @@ $ go get github.com/txthinking/socks5
* Negotiation:
* `type NegotiationRequest struct`
* `func NewNegotiationRequest(methods []byte)`, in client
* `func (r *NegotiationRequest) WriteTo(w *net.TCPConn)`, client writes to server
* `func NewNegotiationRequestFrom(r *net.TCPConn)`, server reads from client
* `func (r *NegotiationRequest) WriteTo(w io.Writer)`, client writes to server
* `func NewNegotiationRequestFrom(r io.Reader)`, server reads from client
* `type NegotiationReply struct`
* `func NewNegotiationReply(method byte)`, in server
* `func (r *NegotiationReply) WriteTo(w *net.TCPConn)`, server writes to client
* `func NewNegotiationReplyFrom(r *net.TCPConn)`, client reads from server
* `func (r *NegotiationReply) WriteTo(w io.Writer)`, server writes to client
* `func NewNegotiationReplyFrom(r io.Reader)`, client reads from server
* User and password negotiation:
* `type UserPassNegotiationRequest struct`
* `func NewUserPassNegotiationRequest(username []byte, password []byte)`, in client
* `func (r *UserPassNegotiationRequest) WriteTo(w *net.TCPConn)`, client writes to server
* `func NewUserPassNegotiationRequestFrom(r *net.TCPConn)`, server reads from client
* `func (r *UserPassNegotiationRequest) WriteTo(w io.Writer)`, client writes to server
* `func NewUserPassNegotiationRequestFrom(r io.Reader)`, server reads from client
* `type UserPassNegotiationReply struct`
* `func NewUserPassNegotiationReply(status byte)`, in server
* `func (r *UserPassNegotiationReply) WriteTo(w *net.TCPConn)`, server writes to client
* `func NewUserPassNegotiationReplyFrom(r *net.TCPConn)`, client reads from server
* `func (r *UserPassNegotiationReply) WriteTo(w io.Writer)`, server writes to client
* `func NewUserPassNegotiationReplyFrom(r io.Reader)`, client reads from server
* Request:
* `type Request struct`
* `func NewRequest(cmd byte, atyp byte, dstaddr []byte, dstport []byte)`, in client
* `func (r *Request) WriteTo(w *net.TCPConn)`, client writes to server
* `func NewRequestFrom(r *net.TCPConn)`, server reads from client
* `func (r *Request) WriteTo(w io.Writer)`, client writes to server
* `func NewRequestFrom(r io.Reader)`, server reads from client
* After server gets the client's *Request, processes...
* Reply:
* `type Reply struct`
* `func NewReply(rep byte, atyp byte, bndaddr []byte, bndport []byte)`, in server
* `func (r *Reply) WriteTo(w *net.TCPConn)`, server writes to client
* `func NewReplyFrom(r *net.TCPConn)`, client reads from server
* `func (r *Reply) WriteTo(w io.Writer)`, server writes to client
* `func NewReplyFrom(r io.Reader)`, client reads from server
* Datagram:
* `type Datagram struct`
* `func NewDatagram(atyp byte, dstaddr []byte, dstport []byte, data []byte)`
Expand Down
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (c *Client) Negotiate() error {
m = MethodUsernamePassword
}
rq := NewNegotiationRequest([]byte{m})
if err := rq.WriteTo(c.TCPConn); err != nil {
if _, err := rq.WriteTo(c.TCPConn); err != nil {
return err
}
rp, err := NewNegotiationReplyFrom(c.TCPConn)
Expand All @@ -72,7 +72,7 @@ func (c *Client) Negotiate() error {
}
if m == MethodUsernamePassword {
urq := NewUserPassNegotiationRequest([]byte(c.UserName), []byte(c.Password))
if err := urq.WriteTo(c.TCPConn); err != nil {
if _, err := urq.WriteTo(c.TCPConn); err != nil {
return err
}
urp, err := NewUserPassNegotiationReplyFrom(c.TCPConn)
Expand All @@ -87,7 +87,7 @@ func (c *Client) Negotiate() error {
}

func (c *Client) Request(r *Request) (*Reply, error) {
if err := r.WriteTo(c.TCPConn); err != nil {
if _, err := r.WriteTo(c.TCPConn); err != nil {
return nil, err
}
rp, err := NewReplyFrom(c.TCPConn)
Expand Down
89 changes: 56 additions & 33 deletions client_side.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,27 @@ func NewNegotiationRequest(methods []byte) *NegotiationRequest {
}

// WriteTo write negotiation request packet into server
func (r *NegotiationRequest) WriteTo(w io.Writer) error {
if _, err := w.Write([]byte{r.Ver}); err != nil {
return err
}
if _, err := w.Write([]byte{r.NMethods}); err != nil {
return err
}
if _, err := w.Write(r.Methods); err != nil {
return err
func (r *NegotiationRequest) WriteTo(w io.Writer) (int64, error) {
var n int
i, err := w.Write([]byte{r.Ver})
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write([]byte{r.NMethods})
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write(r.Methods)
n = n + i
if err != nil {
return int64(n), err
}
if Debug {
log.Printf("Sent NegotiationRequest: %#v %#v %#v\n", r.Ver, r.NMethods, r.Methods)
}
return nil
return int64(n), nil
}

// NewNegotiationReplyFrom read negotiation reply packet from server
Expand Down Expand Up @@ -67,23 +74,32 @@ func NewUserPassNegotiationRequest(username []byte, password []byte) *UserPassNe
}

// WriteTo write user password negotiation request packet into server
func (r *UserPassNegotiationRequest) WriteTo(w io.Writer) error {
if _, err := w.Write([]byte{r.Ver, r.Ulen}); err != nil {
return err
}
if _, err := w.Write(r.Uname); err != nil {
return err
}
if _, err := w.Write([]byte{r.Plen}); err != nil {
return err
}
if _, err := w.Write(r.Passwd); err != nil {
return err
func (r *UserPassNegotiationRequest) WriteTo(w io.Writer) (int64, error) {
var n int
i, err := w.Write([]byte{r.Ver, r.Ulen})
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write(r.Uname)
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write([]byte{r.Plen})
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write(r.Passwd)
n = n + i
if err != nil {
return int64(n), err
}
if Debug {
log.Printf("Sent UserNameNegotiationRequest: %#v %#v %#v %#v %#v\n", r.Ver, r.Ulen, r.Uname, r.Plen, r.Passwd)
}
return nil
return int64(n), nil
}

// NewUserPassNegotiationReplyFrom read user password negotiation reply packet from server
Expand Down Expand Up @@ -120,20 +136,27 @@ func NewRequest(cmd byte, atyp byte, dstaddr []byte, dstport []byte) *Request {
}

// WriteTo write request packet into server
func (r *Request) WriteTo(w io.Writer) error {
if _, err := w.Write([]byte{r.Ver, r.Cmd, r.Rsv, r.Atyp}); err != nil {
return err
}
if _, err := w.Write(r.DstAddr); err != nil {
return err
}
if _, err := w.Write(r.DstPort); err != nil {
return err
func (r *Request) WriteTo(w io.Writer) (int64, error) {
var n int
i, err := w.Write([]byte{r.Ver, r.Cmd, r.Rsv, r.Atyp})
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write(r.DstAddr)
n = n + i
if err != nil {
return int64(n), err
}
i, err = w.Write(r.DstPort)
n = n + i
if err != nil {
return int64(n), err
}
if Debug {
log.Printf("Sent Request: %#v %#v %#v %#v %#v %#v\n", r.Ver, r.Cmd, r.Rsv, r.Atyp, r.DstAddr, r.DstPort)
}
return nil
return int64(n), nil
}

// NewReplyFrom read reply packet from server
Expand Down
6 changes: 3 additions & 3 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (r *Request) Connect(c *net.TCPConn) (*net.TCPConn, error) {
} else {
p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00})
}
if err := p.WriteTo(c); err != nil {
if _, err := p.WriteTo(c); err != nil {
return nil, err
}
return nil, err
Expand All @@ -34,13 +34,13 @@ func (r *Request) Connect(c *net.TCPConn) (*net.TCPConn, error) {
} else {
p = NewReply(RepHostUnreachable, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00})
}
if err := p.WriteTo(c); err != nil {
if _, err := p.WriteTo(c); err != nil {
return nil, err
}
return nil, err
}
p := NewReply(RepSuccess, a, addr, port)
if err := p.WriteTo(c); err != nil {
if _, err := p.WriteTo(c); err != nil {
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package socks5_test
import "github.com/txthinking/socks5"

func ExampleServer() {
s, err := socks5.NewClassicServer("127.0.0.1:1080", "127.0.0.1", "", "", 60, 0, 60, 60)
s, err := socks5.NewClassicServer("127.0.0.1:1081", "127.0.0.1", "", "", 60, 0, 60, 60)
if err != nil {
panic(err)
}
Expand Down
67 changes: 35 additions & 32 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ func (s *Server) Negotiate(c *net.TCPConn) error {
}
if !got {
rp := NewNegotiationReply(MethodUnsupportAll)
if err := rp.WriteTo(c); err != nil {
if _, err := rp.WriteTo(c); err != nil {
return err
}
}
rp := NewNegotiationReply(s.Method)
if err := rp.WriteTo(c); err != nil {
if _, err := rp.WriteTo(c); err != nil {
return err
}

Expand All @@ -117,13 +117,13 @@ func (s *Server) Negotiate(c *net.TCPConn) error {
}
if string(urq.Uname) != s.UserName || string(urq.Passwd) != s.Password {
urp := NewUserPassNegotiationReply(UserPassStatusFailure)
if err := urp.WriteTo(c); err != nil {
if _, err := urp.WriteTo(c); err != nil {
return err
}
return ErrUserPassAuth
}
urp := NewUserPassNegotiationReply(UserPassStatusSuccess)
if err := urp.WriteTo(c); err != nil {
if _, err := urp.WriteTo(c); err != nil {
return err
}
}
Expand Down Expand Up @@ -151,7 +151,7 @@ func (s *Server) GetRequest(c *net.TCPConn) (*Request, error) {
} else {
p = NewReply(RepCommandNotSupported, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00})
}
if err := p.WriteTo(c); err != nil {
if _, err := p.WriteTo(c); err != nil {
return nil, err
}
return nil, ErrUnsupportCmd
Expand Down Expand Up @@ -268,6 +268,32 @@ func (s *Server) Stop() error {
return err1
}

// TCP connection waits for associated UDP to close
func (s *Server) TCPWaitsForUDP(addr *net.UDPAddr) error {
_, p, err := net.SplitHostPort(addr.String())
if err != nil {
return err
}
if p == "0" {
time.Sleep(time.Duration(s.UDPSessionTime) * time.Second)
return nil
}
ch := make(chan byte)
s.TCPUDPAssociate.Set(addr.String(), ch, cache.DefaultExpiration)
<-ch
return nil
}

// UDP releases associated TCP
func (s *Server) UDPReleasesTCP(addr *net.UDPAddr) {
v, ok := s.TCPUDPAssociate.Get(addr.String())
if ok {
ch := v.(chan byte)
ch <- 0x00
s.TCPUDPAssociate.Delete(addr.String())
}
}

// Handler handle tcp, udp request
type Handler interface {
// Request has not been replied yet
Expand Down Expand Up @@ -326,17 +352,9 @@ func (h *DefaultHandle) TCPHandle(s *Server, c *net.TCPConn, r *Request) error {
if err != nil {
return err
}
_, p, err := net.SplitHostPort(caddr.String())
if err != nil {
if err := s.TCPWaitsForUDP(caddr); err != nil {
return err
}
if p == "0" {
time.Sleep(time.Duration(s.UDPSessionTime) * time.Second)
return nil
}
ch := make(chan byte)
s.TCPUDPAssociate.Set(caddr.String(), ch, cache.DefaultExpiration)
<-ch
return nil
}
return ErrUnsupportCmd
Expand Down Expand Up @@ -367,12 +385,7 @@ func (h *DefaultHandle) UDPHandle(s *Server, addr *net.UDPAddr, d *Datagram) err
}
c, err := Dial.Dial("udp", d.Address())
if err != nil {
v, ok := s.TCPUDPAssociate.Get(addr.String())
if ok {
ch := v.(chan byte)
ch <- 0x00
s.TCPUDPAssociate.Delete(addr.String())
}
s.UDPReleasesTCP(addr)
return err
}
// A UDP association terminates when the TCP connection that the UDP
Expand All @@ -386,24 +399,14 @@ func (h *DefaultHandle) UDPHandle(s *Server, addr *net.UDPAddr, d *Datagram) err
log.Printf("Created remote UDP conn for client. client: %#v server: %#v remote: %#v\n", addr.String(), ue.RemoteConn.LocalAddr().String(), d.Address())
}
if err := send(ue, d.Data); err != nil {
v, ok := s.TCPUDPAssociate.Get(ue.ClientAddr.String())
if ok {
ch := v.(chan byte)
ch <- 0x00
s.TCPUDPAssociate.Delete(ue.ClientAddr.String())
}
s.UDPReleasesTCP(ue.ClientAddr)
ue.RemoteConn.Close()
return err
}
s.UDPExchanges.Set(ue.ClientAddr.String(), ue, cache.DefaultExpiration)
go func(ue *UDPExchange) {
defer func() {
v, ok := s.TCPUDPAssociate.Get(ue.ClientAddr.String())
if ok {
ch := v.(chan byte)
ch <- 0x00
s.TCPUDPAssociate.Delete(ue.ClientAddr.String())
}
s.UDPReleasesTCP(ue.ClientAddr)
s.UDPExchanges.Delete(ue.ClientAddr.String())
ue.RemoteConn.Close()
}()
Expand Down
Loading

0 comments on commit 6926f2b

Please sign in to comment.