Skip to content

Commit

Permalink
可选项通过参数传入New
Browse files Browse the repository at this point in the history
  • Loading branch information
ouqiang committed Aug 10, 2018
1 parent 796d2e3 commit 63ad7dc
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 43 deletions.
2 changes: 0 additions & 2 deletions delegate.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ type Delegate interface {

var _ Delegate = &DefaultDelegate{}

var defaultHandler = &DefaultDelegate{}

// DefaultDelegate 默认Handler什么也不做
type DefaultDelegate struct {
Delegate
Expand Down
110 changes: 69 additions & 41 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,73 @@ const (

var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")

//
func makeTunnelRequestLine(addr string) string {
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
}

type options struct {
disableKeepAlive bool
delegate Delegate
transport *http.Transport
}

type Option func(*options)

func WithDisableKeepAlive(disableKeepAlive bool) Option {
return func(opt *options) {
opt.disableKeepAlive = disableKeepAlive
}
}

func WithDelegate(delegate Delegate) Option {
return func(opt *options) {
opt.delegate = delegate
}
}

func WithTransport(t *http.Transport) Option {
return func(opt *options) {
opt.transport = t
}
}

// New 创建proxy实例
func New() *Proxy {
return &Proxy{
Delegate: defaultHandler,
func New(opt ...Option) *Proxy {
opts := &options{}
for _, o := range opt {
o(opts)
}
if opts.delegate == nil {
opts.delegate = &DefaultDelegate{}
}
if opts.transport == nil {
opts.transport = &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}

p := &Proxy{}
p.delegate = opts.delegate
p.transport = opts.transport
p.transport.DisableKeepAlives = opts.disableKeepAlive
p.transport.Proxy = p.delegate.ParentProxy

return p
}

// Proxy 实现了http.Handler接口
type Proxy struct {
DisabledKeepAlive bool
Delegate Delegate
clientConnNum int32
transport http.RoundTripper
delegate Delegate
clientConnNum int32
transport *http.Transport
}

var _ http.Handler = &Proxy{}
Expand All @@ -65,11 +114,11 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
Req: req,
Data: make(map[interface{}]interface{}),
}
p.Delegate.Connect(ctx, rw)
p.delegate.Connect(ctx, rw)
if ctx.abort {
return
}
p.Delegate.Auth(ctx, rw)
p.delegate.Auth(ctx, rw)
if ctx.abort {
return
}
Expand All @@ -80,7 +129,7 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
default:
p.forwardHTTP(ctx, rw)
}
p.Delegate.Finish(ctx)
p.delegate.Finish(ctx)
}

// ClientConnNum 获取客户端连接数
Expand All @@ -90,18 +139,18 @@ func (p *Proxy) ClientConnNum() int32 {

// HTTP转发
func (p *Proxy) forwardHTTP(ctx *Context, rw http.ResponseWriter) {
p.Delegate.BeforeRequest(ctx)
p.delegate.BeforeRequest(ctx)
if ctx.abort {
return
}
removeIssueHeader(ctx.Req.Header)
resp, err := p.roundTripper().RoundTrip(ctx.Req)
p.Delegate.BeforeResponse(ctx, resp, err)
resp, err := p.transport.RoundTrip(ctx.Req)
p.delegate.BeforeResponse(ctx, resp, err)
if ctx.abort {
return
}
if err != nil {
p.Delegate.ErrorLog(fmt.Errorf("HTTP请求错误: [URL: %s], 错误: %s", ctx.Req.URL, err))
p.delegate.ErrorLog(fmt.Errorf("HTTP请求错误: [URL: %s], 错误: %s", ctx.Req.URL, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
Expand All @@ -116,14 +165,14 @@ func (p *Proxy) forwardHTTP(ctx *Context, rw http.ResponseWriter) {
func (p *Proxy) forwardTunnel(ctx *Context, rw http.ResponseWriter) {
clientConn, err := p.hijacker(rw)
if err != nil {
p.Delegate.ErrorLog(err)
p.delegate.ErrorLog(err)
rw.WriteHeader(http.StatusBadGateway)
return
}
defer clientConn.Close()
parentProxyURL, err := p.Delegate.ParentProxy(ctx.Req)
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
if err != nil {
p.Delegate.ErrorLog(fmt.Errorf("解析代理地址错误: [%s] %s", ctx.Req.URL.Host, err))
p.delegate.ErrorLog(fmt.Errorf("解析代理地址错误: [%s] %s", ctx.Req.URL.Host, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
Expand All @@ -134,7 +183,7 @@ func (p *Proxy) forwardTunnel(ctx *Context, rw http.ResponseWriter) {

targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
if err != nil {
p.Delegate.ErrorLog(fmt.Errorf("隧道转发连接目标服务器失败: [%s] [%s]", ctx.Req.URL.Host, err))
p.delegate.ErrorLog(fmt.Errorf("隧道转发连接目标服务器失败: [%s] [%s]", ctx.Req.URL.Host, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
Expand All @@ -144,7 +193,7 @@ func (p *Proxy) forwardTunnel(ctx *Context, rw http.ResponseWriter) {
if parentProxyURL == nil {
_, err = clientConn.Write(tunnelEstablishedResponseLine)
if err != nil {
p.Delegate.ErrorLog(fmt.Errorf("隧道连接成功,通知客户端错误: %s", err))
p.delegate.ErrorLog(fmt.Errorf("隧道连接成功,通知客户端错误: %s", err))
return
}
} else {
Expand Down Expand Up @@ -182,27 +231,6 @@ func (p *Proxy) hijacker(rw http.ResponseWriter) (net.Conn, error) {
return conn, nil
}

func (p *Proxy) roundTripper() http.RoundTripper {
if p.transport != nil {
return p.transport
}
p.transport = &http.Transport{
Proxy: p.Delegate.ParentProxy,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
DisableKeepAlives: p.DisabledKeepAlive,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}

return p.transport
}

// copyHeader 浅拷贝Header
func copyHeader(dst, src http.Header) {
for k, vv := range src {
Expand Down

0 comments on commit 63ad7dc

Please sign in to comment.