diff --git a/v2/connection.go b/v2/connection.go index 610d91de..68a48bba 100755 --- a/v2/connection.go +++ b/v2/connection.go @@ -80,6 +80,7 @@ type Connection struct { type OracleConnector struct { drv *OracleDriver connectString string + dialer network.DialerContext } type OracleDriver struct { //m sync.Mutex @@ -109,6 +110,7 @@ func (connector *OracleConnector) Connect(ctx context.Context) (driver.Conn, err if err != nil { return nil, err } + conn.connOption.Dialer = connector.dialer err = conn.OpenWithContext(ctx) if err != nil { return nil, err @@ -120,6 +122,10 @@ func (connector *OracleConnector) Driver() driver.Driver { return connector.drv } +func (connector *OracleConnector) Dialer(dialer network.DialerContext) { + connector.dialer = dialer +} + // Open return a new open connection func (drv *OracleDriver) Open(name string) (driver.Conn, error) { conn, err := NewConnection(name) diff --git a/v2/network/connect_option.go b/v2/network/connect_option.go index 2ac8e288..03d42064 100755 --- a/v2/network/connect_option.go +++ b/v2/network/connect_option.go @@ -1,6 +1,7 @@ package network import ( + "context" "errors" "fmt" "net" @@ -39,6 +40,11 @@ type DatabaseInfo struct { AuthType int connStr string } + +type DialerContext interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + type SessionInfo struct { SSLVersion string Timeout time.Duration @@ -48,6 +54,7 @@ type SessionInfo struct { Protocol string SSL bool SSLVerify bool + Dialer DialerContext } type AdvNegoSeviceInfo struct { AuthService []string diff --git a/v2/network/session.go b/v2/network/session.go index f033290c..b1d89dfd 100755 --- a/v2/network/session.go +++ b/v2/network/session.go @@ -10,12 +10,13 @@ import ( "encoding/pem" "errors" "fmt" - "github.com/sijms/go-ora/v2/trace" "net" "reflect" "strings" "time" + "github.com/sijms/go-ora/v2/trace" + "github.com/sijms/go-ora/v2/converters" ) @@ -279,8 +280,11 @@ func (session *Session) Connect(ctx context.Context) error { var connected = false var host *ServerAddr var loop = true - dialer := net.Dialer{ - Timeout: time.Second * session.Context.ConnOption.Timeout, + dialer := connOption.Dialer + if dialer == nil { + dialer = &net.Dialer{ + Timeout: time.Second * session.Context.ConnOption.Timeout, + } } for loop { host = connOption.GetActiveServer(false)