Skip to content

Commit

Permalink
*: tiny update for the whitelist plugin (pingcap#9271)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored and jackysp committed Feb 12, 2019
1 parent c190de3 commit c3f64b2
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 2 deletions.
5 changes: 5 additions & 0 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,11 @@ func (do *Domain) SysSessionPool() *sessionPool {
return do.sysSessionPool
}

// GetEtcdClient returns the etcd client.
func (do *Domain) GetEtcdClient() *clientv3.Client {
return do.etcdClient
}

// LoadPrivilegeLoop create a goroutine loads privilege tables in a loop, it
// should be called only once in BootstrapSession.
func (do *Domain) LoadPrivilegeLoop(ctx sessionctx.Context) error {
Expand Down
14 changes: 14 additions & 0 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
Expand Down Expand Up @@ -376,6 +377,19 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error {
defer sysSessionPool.Put(ctx)
err = dom.PrivilegeHandle().Update(ctx.(sessionctx.Context))
return errors.Trace(err)
case ast.FlushStatus:
dom := domain.GetDomain(e.ctx)
if plugin.Get(plugin.Audit, "ipwhitelist") != nil {
if cli := dom.GetEtcdClient(); cli != nil {
const whitelistKey = "/tidb/plugins/whitelist"
row := cli.KV
_, err := row.Put(context.Background(), whitelistKey, "")
if err != nil {
log.Warn("notify update whitelist failed:", err)
}
return errors.Trace(err)
}
}
}
return nil
}
Expand Down
4 changes: 3 additions & 1 deletion plugin/spi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"reflect"
"unsafe"

"github.com/pingcap/parser/auth"
"github.com/pingcap/tidb/sessionctx/variable"
)

Expand Down Expand Up @@ -54,7 +55,8 @@ func ExportManifest(m interface{}) *Manifest {
// AuditManifest presents a sub-manifest that every audit plugin must provide.
type AuditManifest struct {
Manifest
NotifyEvent func(ctx context.Context, sctx *variable.SessionVars) error
NotifyEvent func(ctx context.Context, sctx *variable.SessionVars) error
OnConnectionEvent func(ctx context.Context, u *auth.UserIdentity) error
}

// AuthenticationManifest presents a sub-manifest that every audit plugin must provide.
Expand Down
31 changes: 31 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ import (

"github.com/blacktear23/go-proxyprotocol"
"github.com/pingcap/errors"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -306,6 +308,26 @@ func (s *Server) Run() error {
terror.Log(errors.Trace(err))
break
}

for _, p := range plugin.GetByKind(plugin.Audit) {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
host, err := getPeerHost(conn)
if err != nil {
log.Error(err)
terror.Log(conn.Close())
continue
}

err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: host})
if err != nil {
log.Info(err)
terror.Log(conn.Close())
continue
}
}
}

go s.onConn(conn)
}
err := s.listener.Close()
Expand All @@ -318,6 +340,15 @@ func (s *Server) Run() error {
}
}

func getPeerHost(conn net.Conn) (string, error) {
addr := conn.RemoteAddr().String()
host, _, err := net.SplitHostPort(addr)
if err != nil {
return "", errors.Trace(err)
}
return host, nil
}

func (s *Server) shouldStopListener() bool {
select {
case <-s.stopListenerCh:
Expand Down
10 changes: 9 additions & 1 deletion session/tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,17 @@ type domainMap struct {
}

func (dm *domainMap) Get(store kv.Storage) (d *domain.Domain, err error) {
key := store.UUID()
dm.mu.Lock()
defer dm.mu.Unlock()

// If this is the only domain instance, and the caller doesn't provide store.
if len(dm.domains) == 1 && store == nil {
for _, r := range dm.domains {
return r, nil
}
}

key := store.UUID()
d = dm.domains[key]
if d != nil {
return
Expand Down

0 comments on commit c3f64b2

Please sign in to comment.