diff --git a/Makefile b/Makefile index 5c6af267..431514f9 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,9 @@ docker-build: integration-test: @go clean -testcache go test -tags integration -v ./test/... - go test -tags integration-all -v ./integration_test/... + go test -tags integration-db_tbl -v ./integration_test/scene/db_tbl/... + go test -tags integration-db -v ./integration_test/scene/db/... + go test -tags integration-tbl -v ./integration_test/scene/tbl/... clean: @rm -rf coverage.txt diff --git a/README.md b/README.md index 5b5f61c2..e8898d9c 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ when using `arana`, user doesn't need to care about the `sharding` details of da [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](https://github.com/arana-db/arana/blob/master/LICENSE) [![codecov](https://codecov.io/gh/arana-db/arana/branch/master/graph/badge.svg)](https://codecov.io/gh/arana-db/arana) +[![Go Report Card](https://goreportcard.com/badge/github.com/arana-db/arana)](https://goreportcard.com/report/github.com/arana-db/arana) [![Release](https://img.shields.io/github/v/release/arana-db/arana)](https://img.shields.io/github/v/release/arana-db/arana) [![Docker Pulls](https://img.shields.io/docker/pulls/aranadb/arana)](https://img.shields.io/docker/pulls/aranadb/arana) @@ -21,14 +22,14 @@ when using `arana`, user doesn't need to care about the `sharding` details of da ## Introduction | [中文](https://github.com/arana-db/arana/blob/master/README_CN.md) -First, `Arana` is a Cloud Native Database Proxy. It provides transparent data access capabilities, when using `arana`, +First, `Arana` is a Cloud Native Database Proxy. It provides transparent data access capabilities, when using `arana`, user doesn't need to care about the `sharding` details of database, they can use it just like a single `MySQL` database. -`Arana` also provide abilities of `Multi Tenant`, `Distributed transaction`, `Shadow database`, `SQL Audit`, `Data encrypt / decrypt` +`Arana` also provide abilities of `Multi Tenant`, `Distributed transaction`, `Shadow database`, `SQL Audit`, `Data encrypt / decrypt` and so on. Through simple config, user can use these abilities provided by `arana` directly. -Second, `Arana` can also be deployed as a Database mesh sidecar. As a Database mesh sidecar, arana switches data access from -client mode to proxy mode, which greatly optimizes the startup speed of applications. It provides the ability to manage database -traffic, it takes up very little container resources, doesn't affect the performance of application services in the container, but +Second, `Arana` can also be deployed as a Database mesh sidecar. As a Database mesh sidecar, arana switches data access from +client mode to proxy mode, which greatly optimizes the startup speed of applications. It provides the ability to manage database +traffic, it takes up very little container resources, doesn't affect the performance of application services in the container, but provides all the capabilities of proxy. ## Architecture @@ -74,7 +75,7 @@ arana start -c ${configFilePath} ## Contact -Arana Chinese Community Meeting Time: **Every Saturday At 9:00PM GMT+8** +Arana Chinese Community Meeting Time: **Every Saturday At 9:00PM GMT+8** diff --git a/go.mod b/go.mod index bdf82824..32416a0a 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/arana-db/arana go 1.18 require ( - github.com/arana-db/parser v0.2.2 + github.com/arana-db/parser v0.2.3 github.com/bwmarrin/snowflake v0.3.0 github.com/cespare/xxhash/v2 v2.1.2 github.com/dop251/goja v0.0.0-20220422102209-3faab1d8f20e diff --git a/go.sum b/go.sum index 5e9fbd49..72bbdd43 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,8 @@ github.com/aliyun/alibaba-cloud-sdk-go v1.61.18/go.mod h1:v8ESoHo4SyHmuB4b1tJqDH github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/arana-db/parser v0.2.2 h1:0ndNzQn6Q82s04V2Lp5+cBXSgO592srsISLjhxZ7vQE= -github.com/arana-db/parser v0.2.2/go.mod h1:/XA29bplweWSEAjgoM557ZCzhBilSawUlHcZFjOeDAc= +github.com/arana-db/parser v0.2.3 h1:zLZcx0/oidlHnw/GZYE78NuvwQkHUv2Xtrm2IwyZasA= +github.com/arana-db/parser v0.2.3/go.mod h1:/XA29bplweWSEAjgoM557ZCzhBilSawUlHcZFjOeDAc= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= diff --git a/pkg/boot/boot.go b/pkg/boot/boot.go index 56126b4d..258dbf29 100644 --- a/pkg/boot/boot.go +++ b/pkg/boot/boot.go @@ -30,7 +30,6 @@ import ( "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime" "github.com/arana-db/arana/pkg/runtime/namespace" - "github.com/arana-db/arana/pkg/runtime/optimize" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" ) @@ -129,5 +128,5 @@ func buildNamespace(ctx context.Context, provider Discovery, cluster string) (*n } initCmds = append(initCmds, namespace.UpdateRule(&ru)) - return namespace.New(cluster, optimize.GetOptimizer(), initCmds...), nil + return namespace.New(cluster, initCmds...), nil } diff --git a/pkg/dataset/parallel_test.go b/pkg/dataset/parallel_test.go index 989581d8..a8bc50f2 100644 --- a/pkg/dataset/parallel_test.go +++ b/pkg/dataset/parallel_test.go @@ -187,9 +187,9 @@ func TestParallelDataset_SortBy(t *testing.T) { } var ( - d = generateFakeParallelDataset(ctrl, pairs...) - pojo fakePojo - ids sort.IntSlice + d = generateFakeParallelDataset(ctrl, pairs...) + + ids sort.IntSlice ) for { @@ -214,6 +214,7 @@ func TestParallelDataset_SortBy(t *testing.T) { return } + var pojo fakePojo err = scanPojo(r, &pojo) assert.NoError(t, err) @@ -235,6 +236,8 @@ func TestParallelDataset_SortBy(t *testing.T) { err := d.SetNextN(bingo) assert.NoError(t, err) + var pojo fakePojo + r, err := d.Next() assert.NoError(t, err) err = scanPojo(r, &pojo) diff --git a/pkg/executor/redirect.go b/pkg/executor/redirect.go index 52c3c462..e6829dde 100644 --- a/pkg/executor/redirect.go +++ b/pkg/executor/redirect.go @@ -39,6 +39,7 @@ import ( "github.com/arana-db/arana/pkg/metrics" mysqlErrors "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime" rcontext "github.com/arana-db/arana/pkg/runtime/context" @@ -153,10 +154,20 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re p := parser.New() query := ctx.GetQuery() start := time.Now() - act, err := p.ParseOneStmt(query, "", "") + act, hts, err := p.ParseOneStmtHints(query, "", "") if err != nil { return nil, 0, errors.WithStack(err) } + + var hints []*hint.Hint + for _, next := range hts { + var h *hint.Hint + if h, err = hint.Parse(next); err != nil { + return nil, 0, err + } + hints = append(hints, h) + } + metrics.ParserDuration.Observe(time.Since(start).Seconds()) log.Debugf("ComQuery: %s", query) @@ -172,6 +183,7 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re } ctx.Stmt = &proto.Stmt{ + Hints: hints, StmtNode: act, } diff --git a/pkg/mysql/execute_handle.go b/pkg/mysql/execute_handle.go index 5f28ff2e..f150b5d5 100644 --- a/pkg/mysql/execute_handle.go +++ b/pkg/mysql/execute_handle.go @@ -29,6 +29,7 @@ import ( "github.com/arana-db/arana/pkg/constants/mysql" "github.com/arana-db/arana/pkg/mysql/errors" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/security" "github.com/arana-db/arana/pkg/util/log" ) @@ -263,7 +264,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error { PrepareStmt: query, } p := parser.New() - act, err := p.ParseOneStmt(stmt.PrepareStmt, "", "") + act, hts, err := p.ParseOneStmtHints(stmt.PrepareStmt, "", "") if err != nil { log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err) if wErr := c.writeErrorPacketFromError(err); wErr != nil { @@ -271,6 +272,18 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error { return wErr } } + + for _, it := range hts { + var h *hint.Hint + if h, err = hint.Parse(it); err != nil { + if wErr := c.writeErrorPacketFromError(err); wErr != nil { + log.Errorf("Conn %v: Error writing prepared statement error: %v", c, wErr) + return wErr + } + } + stmt.Hints = append(stmt.Hints, h) + } + stmt.StmtNode = act paramsCount := uint16(strings.Count(query, "?")) diff --git a/pkg/mysql/rows.go b/pkg/mysql/rows.go index dff027a4..6fcce638 100644 --- a/pkg/mysql/rows.go +++ b/pkg/mysql/rows.go @@ -217,7 +217,7 @@ func (bi BinaryRow) Scan(dest []proto.Value) error { case 1, 2, 3, 4, 5, 6: dstlen = 8 + 1 + decimals default: - return errors.Errorf("protocol error, illegal decimals architecture.Value %d", field.decimals) + return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals) } val, err = formatBinaryTime(bi.raw[pos:pos+int(num)], dstlen) dest[i] = val @@ -238,7 +238,7 @@ func (bi BinaryRow) Scan(dest []proto.Value) error { case 1, 2, 3, 4, 5, 6: dstlen = 19 + 1 + decimals default: - return errors.Errorf("protocol error, illegal decimals architecture.Value %d", field.decimals) + return errors.Errorf("protocol error, illegal decimals architecture.V %d", field.decimals) } } val, err = formatBinaryDateTime(bi.raw[pos:pos+int(num)], dstlen) diff --git a/pkg/mysql/utils.go b/pkg/mysql/utils.go index 1a8e70e7..8df16774 100644 --- a/pkg/mysql/utils.go +++ b/pkg/mysql/utils.go @@ -1084,7 +1084,7 @@ func convertAssignRows(dest, src interface{}) error { i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetInt(i64) return nil @@ -1096,7 +1096,7 @@ func convertAssignRows(dest, src interface{}) error { u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetUint(u64) return nil @@ -1108,7 +1108,7 @@ func convertAssignRows(dest, src interface{}) error { f64, err := strconv.ParseFloat(s, dv.Type().Bits()) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.V type %T (%q) to a %s: %v", src, s, dv.Kind(), err) } dv.SetFloat(f64) return nil @@ -1126,7 +1126,7 @@ func convertAssignRows(dest, src interface{}) error { } } - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) + return fmt.Errorf("unsupported Scan, storing driver.V type %T into type %T", src, dest) } func cloneBytes(b []byte) []byte { diff --git a/pkg/proto/hint/hint.go b/pkg/proto/hint/hint.go new file mode 100644 index 00000000..5a1adb87 --- /dev/null +++ b/pkg/proto/hint/hint.go @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hint + +import ( + "bufio" + "bytes" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/runtime/misc" +) + +const ( + _ Type = iota + TypeMaster // force route to master node + TypeSlave // force route to slave node + TypeRoute // custom route + TypeFullScan // enable full-scan + TypeDirect // direct route +) + +var _hintTypes = [...]string{ + TypeMaster: "MASTER", + TypeSlave: "SLAVE", + TypeRoute: "ROUTE", + TypeFullScan: "FULLSCAN", + TypeDirect: "DIRECT", +} + +// KeyValue represents a pair of key and value. +type KeyValue struct { + K string // key (optional) + V string // value +} + +// Type represents the type of Hint. +type Type uint8 + +// String returns the display string. +func (tp Type) String() string { + return _hintTypes[tp] +} + +// Hint represents a Hint, a valid Hint should include type and input kv pairs. +// +// Follow the format below: +// - without inputs: YOUR_HINT() +// - with non-keyed inputs: YOUR_HINT(foo,bar,quz) +// - with keyed inputs: YOUR_HINT(x=foo,y=bar,z=quz) +// +type Hint struct { + Type Type + Inputs []KeyValue +} + +// String returns the display string. +func (h Hint) String() string { + var sb strings.Builder + sb.WriteString(h.Type.String()) + + if len(h.Inputs) < 1 { + sb.WriteString("()") + return sb.String() + } + + sb.WriteByte('(') + + writeKv := func(p KeyValue) { + if key := p.K; len(key) > 0 { + sb.WriteString(key) + sb.WriteByte('=') + } + sb.WriteString(p.V) + } + + writeKv(h.Inputs[0]) + for i := 1; i < len(h.Inputs); i++ { + sb.WriteByte(',') + writeKv(h.Inputs[i]) + } + + sb.WriteByte(')') + return sb.String() +} + +// Parse parses Hint from an input string. +func Parse(s string) (*Hint, error) { + var ( + tpStr string + tp Type + ) + + offset := strings.Index(s, "(") + if offset == -1 { + tpStr = s + } else { + tpStr = s[:offset] + } + + for i, v := range _hintTypes { + if strings.EqualFold(tpStr, v) { + tp = Type(i) + break + } + } + + if tp == 0 { + return nil, errors.Errorf("hint: invalid input '%s'", s) + } + + if offset == -1 { + return &Hint{Type: tp}, nil + } + + end := strings.LastIndex(s, ")") + if end == -1 { + return nil, errors.Errorf("hint: invalid input '%s'", s) + } + + s = s[offset+1 : end] + + scanner := bufio.NewScanner(strings.NewReader(s)) + scanner.Split(scanComma) + + var kvs []KeyValue + + for scanner.Scan() { + text := scanner.Text() + + // split kv by '=' + i := strings.Index(text, "=") + if i == -1 { + // omit blank text + if misc.IsBlank(text) { + continue + } + kvs = append(kvs, KeyValue{V: strings.TrimSpace(text)}) + } else { + var ( + k = strings.TrimSpace(text[:i]) + v = strings.TrimSpace(text[i+1:]) + ) + // omit blank key/value + if misc.IsBlank(k) || misc.IsBlank(v) { + continue + } + kvs = append(kvs, KeyValue{K: k, V: v}) + } + } + + if err := scanner.Err(); err != nil { + return nil, errors.Wrapf(err, "hint: invalid input '%s'", s) + } + + return &Hint{Type: tp, Inputs: kvs}, nil +} + +func scanComma(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, ','); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil +} diff --git a/pkg/proto/hint/hint_test.go b/pkg/proto/hint/hint_test.go new file mode 100644 index 00000000..aa812cc6 --- /dev/null +++ b/pkg/proto/hint/hint_test.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package hint + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + type tt struct { + input string + output string + pass bool + } + + for _, next := range []tt{ + {"route( foo , bar , qux )", "ROUTE(foo,bar,qux)", true}, + {"master", "MASTER()", true}, + {"slave", "SLAVE()", true}, + {"not_exist_hint(1,2,3)", "", false}, + {"route(,,,)", "ROUTE()", true}, + {"fullscan()", "FULLSCAN()", true}, + {"route(foo=111,bar=222,qux=333,)", "ROUTE(foo=111,bar=222,qux=333)", true}, + } { + t.Run(next.input, func(t *testing.T) { + res, err := Parse(next.input) + if next.pass { + assert.NoError(t, err) + assert.Equal(t, next.output, res.String()) + } else { + assert.Error(t, err) + } + }) + } +} diff --git a/pkg/proto/rule/topology.go b/pkg/proto/rule/topology.go index df15bc5a..e3f77f43 100644 --- a/pkg/proto/rule/topology.go +++ b/pkg/proto/rule/topology.go @@ -18,62 +18,154 @@ package rule import ( + "math" "sort" + "sync" ) // Topology represents the topology of databases and tables. type Topology struct { + mu sync.RWMutex dbRender, tbRender func(int) string - idx map[int][]int + idx sync.Map // map[int][]int } // Len returns the length of database and table. func (to *Topology) Len() (dbLen int, tblLen int) { - dbLen = len(to.idx) - for _, v := range to.idx { - tblLen += len(v) - } + to.idx.Range(func(_, value any) bool { + dbLen++ + tblLen += len(value.([]int)) + return true + }) return } // SetTopology sets the topology. func (to *Topology) SetTopology(db int, tables ...int) { - if to.idx == nil { - to.idx = make(map[int][]int) - } - if len(tables) < 1 { - delete(to.idx, db) + to.idx.Delete(db) return } clone := make([]int, len(tables)) copy(clone, tables) sort.Ints(clone) - to.idx[db] = clone + to.idx.Store(db, clone) } // SetRender sets the database/table name render. func (to *Topology) SetRender(dbRender, tbRender func(int) string) { + to.mu.Lock() to.dbRender, to.tbRender = dbRender, tbRender + to.mu.Unlock() } // Render renders the name of database and table from indexes. func (to *Topology) Render(dbIdx, tblIdx int) (string, string, bool) { + to.mu.RLock() + defer to.mu.RUnlock() + if to.tbRender == nil || to.dbRender == nil { return "", "", false } return to.dbRender(dbIdx), to.tbRender(tblIdx), true } +func (to *Topology) EnumerateDatabases() []string { + to.mu.RLock() + render := to.dbRender + to.mu.RUnlock() + + var keys []string + + to.idx.Range(func(key, _ any) bool { + keys = append(keys, render(key.(int))) + return true + }) + + sort.Strings(keys) + + return keys +} + +func (to *Topology) Enumerate() DatabaseTables { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + dt := make(DatabaseTables) + to.Each(func(dbIdx, tbIdx int) bool { + d := dbRender(dbIdx) + t := tbRender(tbIdx) + dt[d] = append(dt[d], t) + return true + }) + + return dt +} + // Each enumerates items in current Topology. func (to *Topology) Each(onEach func(dbIdx, tbIdx int) (ok bool)) bool { - for d, v := range to.idx { + done := true + to.idx.Range(func(key, value any) bool { + var ( + d = key.(int) + v = value.([]int) + ) for _, t := range v { if !onEach(d, t) { + done = false return false } } + return true + }) + + return done +} + +func (to *Topology) Smallest() (db, tb string, ok bool) { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + smallest := [2]int{math.MaxInt64, math.MaxInt64} + to.idx.Range(func(key, value any) bool { + if d := key.(int); d < smallest[0] { + smallest[0] = d + if t := value.([]int); len(t) > 0 { + smallest[1] = t[0] + } + } + return true + }) + + if smallest[0] != math.MaxInt64 || smallest[1] != math.MaxInt64 { + db, tb, ok = dbRender(smallest[0]), tbRender(smallest[1]), true + } + + return +} + +func (to *Topology) Largest() (db, tb string, ok bool) { + to.mu.RLock() + dbRender, tbRender := to.dbRender, to.tbRender + to.mu.RUnlock() + + largest := [2]int{math.MinInt64, math.MinInt64} + to.idx.Range(func(key, value any) bool { + if d := key.(int); d > largest[0] { + largest[0] = d + if t := value.([]int); len(t) > 0 { + largest[1] = t[len(t)-1] + } + } + return true + }) + + if largest[0] != math.MinInt64 || largest[1] != math.MinInt64 { + db, tb, ok = dbRender(largest[0]), tbRender(largest[1]), true } - return true + + return } diff --git a/pkg/proto/rule/topology_test.go b/pkg/proto/rule/topology_test.go index 637f438f..62604e16 100644 --- a/pkg/proto/rule/topology_test.go +++ b/pkg/proto/rule/topology_test.go @@ -33,24 +33,18 @@ func TestLen(t *testing.T) { assert.Equal(t, 6, tblLen) } -func TestSetTopologyForIdxNil(t *testing.T) { - topology := &Topology{ - idx: nil, - } +func TestSetTopology(t *testing.T) { + var topology Topology + topology.SetTopology(2, 2, 3, 4) - for each := range topology.idx { - assert.Equal(t, 2, each) - assert.Equal(t, 3, len(topology.idx[each])) - } dbLen, tblLen := topology.Len() assert.Equal(t, 1, dbLen) assert.Equal(t, 3, tblLen) } -func TestSetTopologyForIdxNotNil(t *testing.T) { - topology := &Topology{ - idx: map[int][]int{0: []int{1, 2, 3}}, - } +func TestSetTopologyNoConflict(t *testing.T) { + var topology Topology + topology.SetTopology(0, 1, 2, 3) topology.SetTopology(1, 4, 5, 6) dbLen, tblLen := topology.Len() assert.Equal(t, 2, dbLen) @@ -58,14 +52,13 @@ func TestSetTopologyForIdxNotNil(t *testing.T) { } func TestSetTopologyForTablesLessThanOne(t *testing.T) { - topology := &Topology{ - idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}}, - } + var topology Topology + + topology.SetTopology(0, 1, 2, 3) + topology.SetTopology(1, 4, 5, 6) + topology.SetTopology(1) - for each := range topology.idx { - assert.Equal(t, 0, each) - assert.Equal(t, 3, len(topology.idx[each])) - } + dbLen, tblLen := topology.Len() assert.Equal(t, 1, dbLen) assert.Equal(t, 3, tblLen) @@ -104,6 +97,36 @@ func TestTopology_Each(t *testing.T) { t.Logf("on each: %d,%d\n", dbIdx, tbIdx) return true }) + + assert.False(t, topology.Each(func(dbIdx, tbIdx int) bool { + return false + })) +} + +func TestTopology_Enumerate(t *testing.T) { + topology := createTopology() + shards := topology.Enumerate() + assert.Greater(t, shards.Len(), 0) +} + +func TestTopology_EnumerateDatabases(t *testing.T) { + topology := createTopology() + dbs := topology.EnumerateDatabases() + assert.Greater(t, len(dbs), 0) +} + +func TestTopology_Largest(t *testing.T) { + topology := createTopology() + db, tb, ok := topology.Largest() + assert.True(t, ok) + t.Logf("largest: %s.%s\n", db, tb) +} + +func TestTopology_Smallest(t *testing.T) { + topology := createTopology() + db, tb, ok := topology.Smallest() + assert.True(t, ok) + t.Logf("smallest: %s.%s\n", db, tb) } func createTopology() *Topology { @@ -114,7 +137,10 @@ func createTopology() *Topology { tbRender: func(i int) string { return fmt.Sprintf("%s:%d", "tbRender", i) }, - idx: map[int][]int{0: []int{1, 2, 3}, 1: []int{4, 5, 6}}, } + + result.SetTopology(0, 1, 2, 3) + result.SetTopology(1, 4, 5, 6) + return result } diff --git a/pkg/proto/runtime.go b/pkg/proto/runtime.go index db87fb47..239031a8 100644 --- a/pkg/proto/runtime.go +++ b/pkg/proto/runtime.go @@ -24,10 +24,6 @@ import ( "time" ) -import ( - "github.com/arana-db/parser/ast" -) - const ( PlanTypeQuery PlanType = iota // QUERY PlanTypeExec // EXEC @@ -56,7 +52,7 @@ type ( // Optimizer represents a sql statement optimizer which can be used to create QueryPlan or ExecPlan. Optimizer interface { // Optimize optimizes the sql with arguments then returns a Plan. - Optimize(ctx context.Context, conn VConn, stmt ast.StmtNode, args ...interface{}) (Plan, error) + Optimize(ctx context.Context) (Plan, error) } // Weight represents the read/write weight info. @@ -114,7 +110,9 @@ type ( Rollback(ctx context.Context) (Result, uint16, error) } + // SchemaLoader represents a schema discovery. SchemaLoader interface { + // Load loads the schema. Load(ctx context.Context, conn VConn, schema string, tables []string) map[string]*TableMetadata } ) diff --git a/pkg/proto/schema_manager/loader_test.go b/pkg/proto/schema_manager/loader_test.go index 53abbb7f..bef3bbeb 100644 --- a/pkg/proto/schema_manager/loader_test.go +++ b/pkg/proto/schema_manager/loader_test.go @@ -15,7 +15,7 @@ * limitations under the License. */ -package schema_manager +package schema_manager_test import ( "context" @@ -25,6 +25,7 @@ import ( import ( "github.com/arana-db/arana/pkg/config" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/schema_manager" "github.com/arana-db/arana/pkg/runtime" "github.com/arana-db/arana/pkg/runtime/namespace" ) @@ -46,7 +47,7 @@ func TestLoader(t *testing.T) { cmds := make([]namespace.Command, 0) cmds = append(cmds, namespace.UpsertDB(groupName, runtime.NewAtomDB(node))) namespaceName := "dongjianhui" - ns := namespace.New(namespaceName, nil, cmds...) + ns := namespace.New(namespaceName, cmds...) namespace.Register(ns) rt, err := runtime.Load(namespaceName) if err != nil { @@ -54,7 +55,7 @@ func TestLoader(t *testing.T) { } schemeName := "employees" tableName := "employees" - s := NewSimpleSchemaLoader() + s := schema_manager.NewSimpleSchemaLoader() s.Load(context.Background(), rt.(proto.VConn), schemeName, []string{tableName}) } diff --git a/pkg/proto/stmt.go b/pkg/proto/stmt.go index 4ea9f5a9..d37c14e1 100644 --- a/pkg/proto/stmt.go +++ b/pkg/proto/stmt.go @@ -21,6 +21,10 @@ import ( "github.com/arana-db/parser/ast" ) +import ( + "github.com/arana-db/arana/pkg/proto/hint" +) + // Stmt is a buffer used for store prepare statement metadata. type Stmt struct { StatementID uint32 @@ -29,5 +33,6 @@ type Stmt struct { ParamsType []int32 ColumnNames []string BindVars map[string]interface{} + Hints []*hint.Hint StmtNode ast.StmtNode } diff --git a/pkg/runtime/ast/alter_table.go b/pkg/runtime/ast/alter_table.go index a60f1596..eab6fd8d 100644 --- a/pkg/runtime/ast/alter_table.go +++ b/pkg/runtime/ast/alter_table.go @@ -179,5 +179,5 @@ func (at *AlterTableStatement) CntParams() int { } func (at *AlterTableStatement) Mode() SQLType { - return SalterTable + return SQLTypeAlterTable } diff --git a/pkg/runtime/ast/ast.go b/pkg/runtime/ast/ast.go index 2ab3da2e..78b09397 100644 --- a/pkg/runtime/ast/ast.go +++ b/pkg/runtime/ast/ast.go @@ -34,6 +34,7 @@ import ( ) import ( + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/runtime/cmp" "github.com/arana-db/arana/pkg/runtime/logical" ) @@ -676,28 +677,46 @@ func convInsertColumns(columnNames []*ast.ColumnName) []string { } // Parse parses the SQL string to Statement. -func Parse(sql string, options ...ParseOption) (Statement, error) { +func Parse(sql string, options ...ParseOption) ([]*hint.Hint, Statement, error) { var o parseOption for _, it := range options { it(&o) } p := parser.New() - s, err := p.ParseOneStmt(sql, o.charset, o.collation) + s, hintStrs, err := p.ParseOneStmtHints(sql, o.charset, o.collation) if err != nil { - return nil, err + return nil, nil, err } - return FromStmtNode(s) + stmt, err := FromStmtNode(s) + if err != nil { + return nil, nil, err + } + + if len(hintStrs) < 1 { + return nil, stmt, nil + } + + hints := make([]*hint.Hint, 0, len(hintStrs)) + for _, it := range hintStrs { + var h *hint.Hint + if h, err = hint.Parse(it); err != nil { + return nil, nil, errors.WithStack(err) + } + hints = append(hints, h) + } + + return hints, stmt, nil } // MustParse parses the SQL string to Statement, panic if failed. -func MustParse(sql string) Statement { - stmt, err := Parse(sql) +func MustParse(sql string) ([]*hint.Hint, Statement) { + hints, stmt, err := Parse(sql) if err != nil { panic(err.Error()) } - return stmt + return hints, stmt } type convCtx struct { diff --git a/pkg/runtime/ast/ast_test.go b/pkg/runtime/ast/ast_test.go index 17dead10..a5c3f0e4 100644 --- a/pkg/runtime/ast/ast_test.go +++ b/pkg/runtime/ast/ast_test.go @@ -46,39 +46,39 @@ func TestParse(t *testing.T) { "select * from student where uid = !0", } { t.Run(sql, func(t *testing.T) { - stmt, err = Parse(sql) + _, stmt, err = Parse(sql) assert.NoError(t, err) t.Log("stmt:", stmt) }) } // 1. select statement - stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age") + _, stmt, err = Parse("select * from student as foo where `name` = if(1>2, 1, 2) order by age") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", stmt) // 2. delete statement - deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)") + _, deleteStmt, err := Parse("delete from student as foo where `name` = if(1>2, 1, 2)") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", deleteStmt) // 3. insert statements - insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'") + _, insertStmtWithSetClause, err := Parse("insert into sink set a=77, b='88'") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithSetClause) - insertStmtWithValues, err := Parse("insert into sink values(1, '2')") + _, insertStmtWithValues, err := Parse("insert into sink values(1, '2')") assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithValues) - insertStmtWithOnDuplicateUpdates, err := Parse( + _, insertStmtWithOnDuplicateUpdates, err := Parse( "insert into sink (a, b) values(1, '2') on duplicate key update a=a+1", ) assert.NoError(t, err, "parse+conv ast failed") t.Logf("stmt:%+v", insertStmtWithOnDuplicateUpdates) // 4. update statement - updateStmt, err := Parse( + _, updateStmt, err := Parse( "update source set a=a+1, b=b+2 where a>1 order by a limit 5", ) assert.NoError(t, err, "parse+conv ast failed") @@ -98,7 +98,7 @@ func TestParse_UnionStmt(t *testing.T) { {"select id,uid,name,nickname from student where uid in (?,?,?) union all select id,uid,name,nickname from tb_user where uid in (?,?,?)", "SELECT `id`,`uid`,`name`,`nickname` FROM `student` WHERE `uid` IN (?,?,?) UNION ALL SELECT `id`,`uid`,`name`,`nickname` FROM `tb_user` WHERE `uid` IN (?,?,?)"}, } { t.Run(next.input, func(t *testing.T) { - stmt, err := Parse(next.input) + _, stmt, err := Parse(next.input) assert.NoError(t, err, "should parse ok") assert.IsType(t, (*UnionSelectStatement)(nil), stmt, "should be union statement") @@ -162,7 +162,7 @@ func TestParse_SelectStmt(t *testing.T) { {"select null as pkid", "SELECT NULL AS `pkid`"}, } { t.Run(next.input, func(t *testing.T) { - stmt, err := Parse(next.input) + _, stmt, err := Parse(next.input) assert.NoError(t, err, "should parse ok") assert.IsType(t, (*SelectStatement)(nil), stmt, "should be select statement") @@ -185,7 +185,7 @@ func TestParse_DeleteStmt(t *testing.T) { {"delete low_priority quick ignore from student where id = 1", "DELETE LOW_PRIORITY QUICK IGNORE FROM `student` WHERE `id` = 1"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsType(t, (*DeleteStatement)(nil), stmt, "should be delete statement") @@ -206,7 +206,7 @@ func TestParse_DescribeStatement(t *testing.T) { {"desc foobar", "DESC `foobar`"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsType(t, (*DescribeStatement)(nil), stmt, "should be describe statement") @@ -243,7 +243,7 @@ func TestParse_ShowStatement(t *testing.T) { {"show create table `foo`", (*ShowCreate)(nil), "SHOW CREATE TABLE `foo`"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, it.expectTyp, stmt, "should be %T", it.expectTyp) @@ -256,7 +256,7 @@ func TestParse_ShowStatement(t *testing.T) { } func TestParse_ExplainStmt(t *testing.T) { - stmt, err := Parse("explain select * from student where uid = 1") + _, stmt, err := Parse("explain select * from student where uid = 1") assert.NoError(t, err) assert.IsType(t, (*ExplainStatement)(nil), stmt) s := MustRestoreToString(RestoreDefault, stmt) @@ -291,7 +291,7 @@ func TestParseMore(t *testing.T) { for _, sql := range tbls { t.Run(sql, func(t *testing.T) { - _, err := Parse(sql) + _, _, err := Parse(sql) assert.NoError(t, err) }) } @@ -308,7 +308,7 @@ func TestParse_UpdateStmt(t *testing.T) { {"update low_priority student set nickname = ? where id = 1 limit 1", "UPDATE LOW_PRIORITY `student` SET `nickname` = ? WHERE `id` = 1 LIMIT 1"}, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*UpdateStatement)(nil), stmt, "should be update statement") @@ -342,7 +342,7 @@ func TestParse_InsertStmt(t *testing.T) { }, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*InsertStatement)(nil), stmt, "should be insert statement") @@ -371,7 +371,7 @@ func TestParse_InsertStmt(t *testing.T) { }, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*InsertSelectStatement)(nil), stmt, "should be insert-select statement") @@ -384,7 +384,7 @@ func TestParse_InsertStmt(t *testing.T) { } func TestRestoreCount(t *testing.T) { - stmt := MustParse("select count(1)") + _, stmt := MustParse("select count(1)") sel := stmt.(*SelectStatement) var sb strings.Builder _ = sel.Restore(RestoreDefault, &sb, nil) @@ -392,7 +392,7 @@ func TestRestoreCount(t *testing.T) { } func TestQuote(t *testing.T) { - stmt := MustParse("select `a``bc`") + _, stmt := MustParse("select `a``bc`") sel := stmt.(*SelectStatement) var sb strings.Builder _ = sel.Restore(RestoreDefault, &sb, nil) @@ -437,7 +437,7 @@ func TestParse_AlterTableStmt(t *testing.T) { }, } { t.Run(it.input, func(t *testing.T) { - stmt, err := Parse(it.input) + _, stmt, err := Parse(it.input) assert.NoError(t, err) assert.IsTypef(t, (*AlterTableStatement)(nil), stmt, "should be alter table statement") @@ -450,7 +450,7 @@ func TestParse_AlterTableStmt(t *testing.T) { } func TestParse_DescStmt(t *testing.T) { - stmt := MustParse("desc student id") + _, stmt := MustParse("desc student id") // In MySQL, the case of "desc student 'id'" will be parsed successfully, // but in arana, it will get an error by tidb parser. desc := stmt.(*DescribeStatement) diff --git a/pkg/runtime/ast/create_index.go b/pkg/runtime/ast/create_index.go index 6f6a2f56..050713bc 100644 --- a/pkg/runtime/ast/create_index.go +++ b/pkg/runtime/ast/create_index.go @@ -17,7 +17,9 @@ package ast -import "strings" +import ( + "strings" +) var ( _ Statement = (*CreateIndexStatement)(nil) @@ -63,5 +65,5 @@ func (c *CreateIndexStatement) Validate() error { } func (c *CreateIndexStatement) Mode() SQLType { - return CreateIndex + return SQLTypeCreateIndex } diff --git a/pkg/runtime/ast/delete.go b/pkg/runtime/ast/delete.go index 90987637..b8b38f05 100644 --- a/pkg/runtime/ast/delete.go +++ b/pkg/runtime/ast/delete.go @@ -109,7 +109,7 @@ func (ds *DeleteStatement) CntParams() int { } func (ds *DeleteStatement) Mode() SQLType { - return Sdelete + return SQLTypeDelete } func (ds *DeleteStatement) IsLowPriority() bool { diff --git a/pkg/runtime/ast/describe.go b/pkg/runtime/ast/describe.go index ab25ae27..3523ff48 100644 --- a/pkg/runtime/ast/describe.go +++ b/pkg/runtime/ast/describe.go @@ -59,7 +59,7 @@ func (d *DescribeStatement) CntParams() int { } func (d *DescribeStatement) Mode() SQLType { - return Squery + return SQLTypeDescribe } // ExplainStatement represents mysql explain statement. see https://dev.mysql.com/doc/refman/8.0/en/explain.html @@ -88,5 +88,5 @@ func (e *ExplainStatement) CntParams() int { } func (e *ExplainStatement) Mode() SQLType { - return Squery + return SQLTypeSelect } diff --git a/pkg/runtime/ast/drop_index.go b/pkg/runtime/ast/drop_index.go index 7dea83f8..58918a21 100644 --- a/pkg/runtime/ast/drop_index.go +++ b/pkg/runtime/ast/drop_index.go @@ -54,5 +54,5 @@ func (d *DropIndexStatement) Validate() error { } func (d *DropIndexStatement) Mode() SQLType { - return DropIndex + return SQLTypeDropIndex } diff --git a/pkg/runtime/ast/drop_table.go b/pkg/runtime/ast/drop_table.go index 1bf3b63c..1b277871 100644 --- a/pkg/runtime/ast/drop_table.go +++ b/pkg/runtime/ast/drop_table.go @@ -55,5 +55,5 @@ func (d DropTableStatement) Validate() error { } func (d DropTableStatement) Mode() SQLType { - return SdropTable + return SQLTypeDropTable } diff --git a/pkg/runtime/ast/insert.go b/pkg/runtime/ast/insert.go index b2cad95b..6a70cea5 100644 --- a/pkg/runtime/ast/insert.go +++ b/pkg/runtime/ast/insert.go @@ -159,7 +159,7 @@ func (r *ReplaceStatement) Values() [][]ExpressionNode { } func (r *ReplaceStatement) Mode() SQLType { - return Sreplace + return SQLTypeReplace } func (r *ReplaceStatement) CntParams() int { @@ -343,7 +343,7 @@ func (is *InsertStatement) CntParams() int { } func (is *InsertStatement) Mode() SQLType { - return Sinsert + return SQLTypeInsert } type ReplaceSelectStatement struct { @@ -369,7 +369,7 @@ func (r *ReplaceSelectStatement) CntParams() int { } func (r *ReplaceSelectStatement) Mode() SQLType { - return Sreplace + return SQLTypeReplace } type InsertSelectStatement struct { @@ -461,5 +461,5 @@ func (is *InsertSelectStatement) CntParams() int { } func (is *InsertSelectStatement) Mode() SQLType { - return Sinsert + return SQLTypeInsertSelect } diff --git a/pkg/runtime/ast/proto.go b/pkg/runtime/ast/proto.go index 7fb23d09..5e5dbf38 100644 --- a/pkg/runtime/ast/proto.go +++ b/pkg/runtime/ast/proto.go @@ -22,18 +22,28 @@ import ( ) const ( - _ SQLType = iota - Squery // QUERY - Sdelete // DELETE - Supdate // UPDATE - Sinsert // INSERT - Sreplace // REPLACE - Struncate // TRUNCATE - SdropTable // DROP TABLE - SalterTable // ALTER TABLE - DropIndex // DROP INDEX - CreateIndex // CREATE INDEX - DropTrigger // DROP TRIGGER + _ SQLType = iota + SQLTypeSelect // SELECT + SQLTypeDelete // DELETE + SQLTypeUpdate // UPDATE + SQLTypeInsert // INSERT + SQLTypeInsertSelect // INSERT SELECT + SQLTypeReplace // REPLACE + SQLTypeTruncate // TRUNCATE + SQLTypeDropTable // DROP TABLE + SQLTypeAlterTable // ALTER TABLE + SQLTypeDropIndex // DROP INDEX + SQLTypeShowDatabases // SHOW DATABASES + SQLTypeShowTables // SHOW TABLES + SQLTypeShowOpenTables // SHOW OPEN TABLES + SQLTypeShowIndex // SHOW INDEX + SQLTypeShowColumns // SHOW COLUMNS + SQLTypeShowCreate // SHOW CREATE + SQLTypeShowVariables // SHOW VARIABLES + SQLTypeDescribe // DESCRIBE + SQLTypeUnion // UNION + SQLTypeDropTrigger // DROP TRIGGER + SQLTypeCreateIndex // CREATE INDEX ) type RestoreFlag uint32 @@ -48,16 +58,27 @@ type Restorer interface { } var _sqlTypeNames = [...]string{ - Squery: "QUERY", - Sdelete: "DELETE", - Supdate: "UPDATE", - Sinsert: "INSERT", - Sreplace: "REPLACE", - Struncate: "TRUNCATE", - SdropTable: "DROP TABLE", - SalterTable: "ALTER TABLE", - DropIndex: "DROP INDEX", - CreateIndex: "CREATE INDEX", + SQLTypeSelect: "SELECT", + SQLTypeDelete: "DELETE", + SQLTypeUpdate: "UPDATE", + SQLTypeInsert: "INSERT", + SQLTypeInsertSelect: "INSERT SELECT", + SQLTypeReplace: "REPLACE", + SQLTypeTruncate: "TRUNCATE", + SQLTypeDropTable: "DROP TABLE", + SQLTypeAlterTable: "ALTER TABLE", + SQLTypeDropIndex: "DROP INDEX", + SQLTypeShowDatabases: "SHOW DATABASES", + SQLTypeShowTables: "SHOW TABLES", + SQLTypeShowOpenTables: "SHOW OPEN TABLES", + SQLTypeShowIndex: "SHOW INDEX", + SQLTypeShowColumns: "SHOW COLUMNS", + SQLTypeShowCreate: "SHOW CREATE", + SQLTypeShowVariables: "SHOW VARIABLES", + SQLTypeDescribe: "DESCRIBE", + SQLTypeUnion: "UNION", + SQLTypeDropTrigger: "DROP TRIGGER", + SQLTypeCreateIndex: "CREATE INDEX", } // SQLType represents the type of SQL. diff --git a/pkg/runtime/ast/select.go b/pkg/runtime/ast/select.go index af0ebff1..4366bfff 100644 --- a/pkg/runtime/ast/select.go +++ b/pkg/runtime/ast/select.go @@ -252,7 +252,7 @@ func (ss *SelectStatement) Validate() error { } func (ss *SelectStatement) Mode() SQLType { - return Squery + return SQLTypeSelect } func (ss *SelectStatement) CntParams() int { diff --git a/pkg/runtime/ast/show.go b/pkg/runtime/ast/show.go index a8b07cfb..cda06524 100644 --- a/pkg/runtime/ast/show.go +++ b/pkg/runtime/ast/show.go @@ -28,6 +28,7 @@ import ( var ( _ Statement = (*ShowTables)(nil) + _ Statement = (*ShowOpenTables)(nil) _ Statement = (*ShowCreate)(nil) _ Statement = (*ShowDatabases)(nil) _ Statement = (*ShowColumns)(nil) @@ -69,14 +70,14 @@ func (bs *baseShow) CntParams() int { return 0 } -func (bs *baseShow) Mode() SQLType { - return Squery -} - type ShowDatabases struct { *baseShow } +func (s ShowDatabases) Mode() SQLType { + return SQLTypeShowDatabases +} + func (s ShowDatabases) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW DATABASES") if err := s.baseShow.Restore(flag, sb, args); err != nil { @@ -93,6 +94,10 @@ type ShowTables struct { *baseShow } +func (s ShowTables) Mode() SQLType { + return SQLTypeShowTables +} + func (s ShowTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW TABLES") if err := s.baseShow.Restore(flag, sb, args); err != nil { @@ -109,6 +114,10 @@ type ShowOpenTables struct { *baseShow } +func (s ShowOpenTables) Mode() SQLType { + return SQLTypeShowOpenTables +} + func (s ShowOpenTables) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { sb.WriteString("SHOW OPEN TABLES") if err := s.baseShow.Restore(flag, sb, args); err != nil { @@ -189,7 +198,7 @@ func (s *ShowCreate) CntParams() int { } func (s *ShowCreate) Mode() SQLType { - return Squery + return SQLTypeShowCreate } type ShowIndex struct { @@ -231,7 +240,7 @@ func (s *ShowIndex) CntParams() int { } func (s *ShowIndex) Mode() SQLType { - return Squery + return SQLTypeShowIndex } type showColumnsFlag uint8 @@ -304,7 +313,7 @@ func (sh *ShowColumns) CntParams() int { } func (sh *ShowColumns) Mode() SQLType { - return Squery + return SQLTypeShowColumns } func (sh *ShowColumns) Full() bool { @@ -364,5 +373,5 @@ func (s *ShowVariables) CntParams() int { } func (s *ShowVariables) Mode() SQLType { - return Squery + return SQLTypeShowVariables } diff --git a/pkg/runtime/ast/trigger.go b/pkg/runtime/ast/trigger.go index 177029cd..31c5b40b 100644 --- a/pkg/runtime/ast/trigger.go +++ b/pkg/runtime/ast/trigger.go @@ -35,7 +35,7 @@ func (d DropTriggerStatement) CntParams() int { } func (d DropTriggerStatement) Mode() SQLType { - return DropTrigger + return SQLTypeDropTrigger } func (d DropTriggerStatement) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { diff --git a/pkg/runtime/ast/truncate.go b/pkg/runtime/ast/truncate.go index aa5ea87c..573b634c 100644 --- a/pkg/runtime/ast/truncate.go +++ b/pkg/runtime/ast/truncate.go @@ -64,5 +64,5 @@ func (stmt *TruncateStatement) CntParams() int { } func (stmt *TruncateStatement) Mode() SQLType { - return Struncate + return SQLTypeTruncate } diff --git a/pkg/runtime/ast/union.go b/pkg/runtime/ast/union.go index 57ce0cbc..61a7b07f 100644 --- a/pkg/runtime/ast/union.go +++ b/pkg/runtime/ast/union.go @@ -112,7 +112,7 @@ func (u *UnionSelectStatement) OrderBy() OrderByNode { } func (u *UnionSelectStatement) Mode() SQLType { - return Squery + return SQLTypeUnion } func (u *UnionSelectStatement) First() *SelectStatement { diff --git a/pkg/runtime/ast/update.go b/pkg/runtime/ast/update.go index 93c855f3..26a2038b 100644 --- a/pkg/runtime/ast/update.go +++ b/pkg/runtime/ast/update.go @@ -143,5 +143,5 @@ func (u *UpdateStatement) CntParams() int { } func (u *UpdateStatement) Mode() SQLType { - return Supdate + return SQLTypeUpdate } diff --git a/pkg/runtime/context/context.go b/pkg/runtime/context/context.go index 9596dc45..55925e35 100644 --- a/pkg/runtime/context/context.go +++ b/pkg/runtime/context/context.go @@ -23,7 +23,6 @@ import ( import ( "github.com/arana-db/arana/pkg/proto" - "github.com/arana-db/arana/pkg/proto/rule" ) const ( @@ -34,7 +33,6 @@ const ( type ( keyFlag struct{} - keyRule struct{} keySequence struct{} keySql struct{} keyNodeLabel struct{} @@ -71,11 +69,6 @@ func WithDBGroup(ctx context.Context, group string) context.Context { return context.WithValue(ctx, keyDefaultDBGroup{}, group) } -// WithRule binds a rule. -func WithRule(ctx context.Context, ru *rule.Rule) context.Context { - return context.WithValue(ctx, keyRule{}, ru) -} - func WithSchema(ctx context.Context, data string) context.Context { return context.WithValue(ctx, keySchema{}, data) } @@ -122,15 +115,6 @@ func DBGroup(ctx context.Context) string { return db } -// Rule extracts the rule. -func Rule(ctx context.Context) *rule.Rule { - ru, ok := ctx.Value(keyRule{}).(*rule.Rule) - if !ok { - return nil - } - return ru -} - // IsRead returns true if this is a read operation func IsRead(ctx context.Context) bool { return hasFlag(ctx, _flagRead) diff --git a/pkg/runtime/function/function_test.go b/pkg/runtime/function/function_test.go index 64ff195d..b17326a9 100644 --- a/pkg/runtime/function/function_test.go +++ b/pkg/runtime/function/function_test.go @@ -104,7 +104,7 @@ func BenchmarkEval(b *testing.B) { } func mustGetMathAtom() *ast.MathExpressionAtom { - stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)") + _, stmt, err := ast.Parse("select * from t where a = 1 + if(?,1,0)") if err != nil { panic(err.Error()) } diff --git a/pkg/runtime/namespace/namespace.go b/pkg/runtime/namespace/namespace.go index fe0b08aa..4962729c 100644 --- a/pkg/runtime/namespace/namespace.go +++ b/pkg/runtime/namespace/namespace.go @@ -76,8 +76,7 @@ type ( name string // the name of Namespace - rule atomic.Value // *rule.Rule - optimizer proto.Optimizer + rule atomic.Value // *rule.Rule // datasource map, eg: employee_0001 -> [mysql-a,mysql-b,mysql-c], ... employee_0007 -> [mysql-x,mysql-y,mysql-z] dss atomic.Value // map[string][]proto.DB @@ -91,12 +90,11 @@ type ( ) // New creates a Namespace. -func New(name string, optimizer proto.Optimizer, commands ...Command) *Namespace { +func New(name string, commands ...Command) *Namespace { ns := &Namespace{ - name: name, - optimizer: optimizer, - cmds: make(chan Command, 1), - done: make(chan struct{}), + name: name, + cmds: make(chan Command, 1), + done: make(chan struct{}), } ns.dss.Store(make(map[string][]proto.DB)) // init empty map ns.rule.Store(&rule.Rule{}) // init empty rule @@ -166,11 +164,6 @@ func (ns *Namespace) DB(ctx context.Context, group string) proto.DB { return exist[target] } -// Optimizer returns the optimizer. -func (ns *Namespace) Optimizer() proto.Optimizer { - return ns.optimizer -} - // Rule returns the sharding rule. func (ns *Namespace) Rule() *rule.Rule { ru, ok := ns.rule.Load().(*rule.Rule) diff --git a/pkg/runtime/namespace/namespace_test.go b/pkg/runtime/namespace/namespace_test.go index d579c5ec..5fb72f34 100644 --- a/pkg/runtime/namespace/namespace_test.go +++ b/pkg/runtime/namespace/namespace_test.go @@ -44,11 +44,7 @@ func TestRegister(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) - - const ( - name = "employees" - ) + const name = "employees" getDB := func(i int) proto.DB { db := testdata.NewMockDB(ctrl) @@ -58,7 +54,7 @@ func TestRegister(t *testing.T) { return db } - err := Register(New(name, opt, UpsertDB(getGroup(0), getDB(1)))) + err := Register(New(name, UpsertDB(getGroup(0), getDB(1)))) assert.NoError(t, err, "should register namespace ok") defer func() { @@ -89,8 +85,6 @@ func TestGetDBByWeight(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) - const ( name = "account" ) @@ -104,7 +98,7 @@ func TestGetDBByWeight(t *testing.T) { } // when doing read operation, db 3 is the max // when doing write operation, db 2 is the max - err := Register(New(name, opt, + err := Register(New(name, UpsertDB(getGroup(0), getDB(1, 9, 1)), UpsertDB(getGroup(0), getDB(2, 10, 5)), UpsertDB(getGroup(0), getDB(3, 3, 10)), diff --git a/pkg/runtime/optimize/alter_table.go b/pkg/runtime/optimize/alter_table.go new file mode 100644 index 00000000..f91e2fab --- /dev/null +++ b/pkg/runtime/optimize/alter_table.go @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeAlterTable, optimizeAlterTable) +} + +func optimizeAlterTable(_ context.Context, o *optimizer) (proto.Plan, error) { + var ( + stmt = o.stmt.(*ast.AlterTableStatement) + ret = plan.NewAlterTablePlan(stmt) + table = stmt.Table + vt *rule.VTable + ok bool + ) + ret.BindArgs(o.args) + + // non-sharding update + if vt, ok = o.rule.VTable(table.Suffix()); !ok { + return ret, nil + } + + //TODO alter table table or column to new name , should update sharding info + + // exit if full-scan is disabled + if !vt.AllowFullScan() { + return nil, errDenyFullScan + } + + // sharding + shards := rule.DatabaseTables{} + topology := vt.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + ret.Shards = shards + return ret, nil +} diff --git a/pkg/runtime/optimize/create_index.go b/pkg/runtime/optimize/create_index.go new file mode 100644 index 00000000..2ce8ddc0 --- /dev/null +++ b/pkg/runtime/optimize/create_index.go @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeCreateIndex, optimizeCreateIndex) +} + +func optimizeCreateIndex(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.CreateIndexStatement) + ret := plan.NewCreateIndexPlan(stmt) + vt, ok := o.rule.VTable(stmt.Table.Suffix()) + + // table shard + if !ok { + return ret, nil + } + + // sharding + shards := rule.DatabaseTables{} + topology := vt.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + ret.SetShard(shards) + return ret, nil +} diff --git a/pkg/runtime/optimize/delete.go b/pkg/runtime/optimize/delete.go new file mode 100644 index 00000000..5bd53fdc --- /dev/null +++ b/pkg/runtime/optimize/delete.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeDelete, optimizeDelete) +} + +func optimizeDelete(ctx context.Context, o *optimizer) (proto.Plan, error) { + var ( + stmt = o.stmt.(*ast.DeleteStatement) + ) + + shards, err := o.computeShards(stmt.Table, stmt.Where, o.args) + if err != nil { + return nil, errors.Wrap(err, "failed to optimize DELETE statement") + } + + // TODO: delete from a child sharding-table directly + + if shards == nil { + return plan.Transparent(stmt, o.args), nil + } + + ret := plan.NewSimpleDeletePlan(stmt) + ret.BindArgs(o.args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/optimize/describe.go b/pkg/runtime/optimize/describe.go new file mode 100644 index 00000000..1656ba5e --- /dev/null +++ b/pkg/runtime/optimize/describe.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeDescribe, optimizeDescribeStatement) +} + +func optimizeDescribeStatement(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.DescribeStatement) + vts := o.rule.VTables() + vtName := []string(stmt.Table)[0] + ret := plan.NewDescribePlan(stmt) + ret.BindArgs(o.args) + + if vTable, ok := vts[vtName]; ok { + shards := rule.DatabaseTables{} + // compute all tables + topology := vTable.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + dbName, tblName := shards.Smallest() + ret.Database = dbName + ret.Table = tblName + ret.Column = stmt.Column + } + + return ret, nil +} diff --git a/pkg/runtime/optimize/drop_index.go b/pkg/runtime/optimize/drop_index.go new file mode 100644 index 00000000..19676451 --- /dev/null +++ b/pkg/runtime/optimize/drop_index.go @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeDropIndex, optimizeDropIndex) +} + +func optimizeDropIndex(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.DropIndexStatement) + //table shard + + shard, err := o.computeShards(stmt.Table, nil, o.args) + if err != nil { + return nil, err + } + if len(shard) == 0 { + return plan.Transparent(stmt, o.args), nil + } + + shardPlan := plan.NewDropIndexPlan(stmt) + shardPlan.SetShard(shard) + shardPlan.BindArgs(o.args) + return shardPlan, nil +} diff --git a/pkg/runtime/optimize/drop_table.go b/pkg/runtime/optimize/drop_table.go new file mode 100644 index 00000000..27f55779 --- /dev/null +++ b/pkg/runtime/optimize/drop_table.go @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeDropTable, optimizeDropTable) +} + +func optimizeDropTable(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.DropTableStatement) + //table shard + var shards []rule.DatabaseTables + //tables not shard + noShardStmt := ast.NewDropTableStatement() + for _, table := range stmt.Tables { + shard, err := o.computeShards(*table, nil, o.args) + if err != nil { + return nil, err + } + if shard == nil { + noShardStmt.Tables = append(noShardStmt.Tables, table) + continue + } + shards = append(shards, shard) + } + + shardPlan := plan.NewDropTablePlan(stmt) + shardPlan.BindArgs(o.args) + shardPlan.SetShards(shards) + + if len(noShardStmt.Tables) == 0 { + return shardPlan, nil + } + + noShardPlan := plan.Transparent(noShardStmt, o.args) + + return &plan.UnionPlan{ + Plans: []proto.Plan{ + noShardPlan, shardPlan, + }, + }, nil +} diff --git a/pkg/runtime/optimize/drop_trigger.go b/pkg/runtime/optimize/drop_trigger.go new file mode 100644 index 00000000..7af7db37 --- /dev/null +++ b/pkg/runtime/optimize/drop_trigger.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeDropTrigger, optimizeTrigger) +} + +func optimizeTrigger(_ context.Context, o *optimizer) (proto.Plan, error) { + shards := rule.DatabaseTables{} + for _, table := range o.rule.VTables() { + topology := table.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + + break + } + + ret := &plan.DropTriggerPlan{Stmt: o.stmt.(*ast.DropTriggerStatement), Shards: shards} + ret.BindArgs(o.args) + return ret, nil +} diff --git a/pkg/runtime/optimize/insert.go b/pkg/runtime/optimize/insert.go new file mode 100644 index 00000000..a476a21a --- /dev/null +++ b/pkg/runtime/optimize/insert.go @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/cmp" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeInsert, optimizeInsert) + registerOptimizeHandler(ast.SQLTypeInsertSelect, optimizeInsertSelect) +} + +func optimizeInsert(ctx context.Context, o *optimizer) (proto.Plan, error) { + ret := plan.NewSimpleInsertPlan() + ret.BindArgs(o.args) + + var ( + stmt = o.stmt.(*ast.InsertStatement) + vt *rule.VTable + ok bool + ) + + if vt, ok = o.rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table + ret.Put("", stmt) + return ret, nil + } + + // TODO: handle multiple shard keys. + + bingo := -1 + // check existing shard columns + for i, col := range stmt.Columns() { + if _, _, ok = vt.GetShardMetadata(col); ok { + bingo = i + break + } + } + + if bingo < 0 { + return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") + } + + //check on duplicated key update + for _, upd := range stmt.DuplicatedUpdates() { + if upd.Column.Suffix() == stmt.Columns()[bingo] { + return nil, errors.New("do not support update sharding key") + } + } + + var ( + sharder = (*Sharder)(o.rule) + left = ast.ColumnNameExpressionAtom(make([]string, 1)) + filter = &ast.PredicateExpressionNode{ + P: &ast.BinaryComparisonPredicateNode{ + Left: &ast.AtomPredicateNode{ + A: left, + }, + Op: cmp.Ceq, + }, + } + slots = make(map[string]map[string][]int) // (db,table,valuesIndex) + ) + + // reset filter + resetFilter := func(column string, value ast.ExpressionNode) { + left[0] = column + filter.P.(*ast.BinaryComparisonPredicateNode).Right = value.(*ast.PredicateExpressionNode).P + } + + for i, values := range stmt.Values() { + value := values[bingo] + resetFilter(stmt.Columns()[bingo], value) + + shards, _, err := sharder.Shard(stmt.Table(), filter, o.args...) + + if err != nil { + return nil, errors.WithStack(err) + } + + if shards.Len() != 1 { + return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") + } + + var ( + db string + table string + ) + + for k, v := range shards { + db = k + table = v[0] + break + } + + if _, ok = slots[db]; !ok { + slots[db] = make(map[string][]int) + } + slots[db][table] = append(slots[db][table], i) + } + + for db, slot := range slots { + for table, indexes := range slot { + // clone insert stmt without values + newborn := ast.NewInsertStatement(ast.TableName{table}, stmt.Columns()) + newborn.SetFlag(stmt.Flag()) + newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates()) + + // collect values with same table + values := make([][]ast.ExpressionNode, 0, len(indexes)) + for _, i := range indexes { + values = append(values, stmt.Values()[i]) + } + newborn.SetValues(values) + + rewriteInsertStatement(ctx, o, newborn, db, table) + ret.Put(db, newborn) + } + } + + return ret, nil +} + +func optimizeInsertSelect(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.InsertSelectStatement) + + ret := plan.NewInsertSelectPlan() + + ret.BindArgs(o.args) + + if _, ok := o.rule.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table + ret.Batch[""] = stmt + return ret, nil + } + + // TODO: handle shard keys. + + return nil, errors.New("not support insert-select into sharding table") +} + +func rewriteInsertStatement(ctx context.Context, o *optimizer, stmt *ast.InsertStatement, db, tb string) error { + metaData := o.schemaLoader.Load(ctx, o.vconn, db, []string{tb})[tb] + if metaData == nil || len(metaData.ColumnNames) == 0 { + return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) + } + + if len(metaData.ColumnNames) == len(stmt.Columns()) { + // User had explicitly specified every value + return nil + } + columnsMetadata := metaData.Columns + + for _, colName := range stmt.Columns() { + if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated { + // User had explicitly specified auto-generated primary key column + return nil + } + } + + pkColName := "" + for name, column := range columnsMetadata { + if column.PrimaryKey && column.Generated { + pkColName = name + break + } + } + if len(pkColName) < 1 { + // There's no auto-generated primary key column + return nil + } + + // TODO rewrite columns and add distributed primary key + //stmt.SetColumns(append(stmt.Columns(), pkColName)) + // append value of distributed primary key + //newValues := stmt.Values() + //for _, newValue := range newValues { + // newValue = append(newValue, ) + //} + return nil +} diff --git a/pkg/runtime/optimize/optimizer.go b/pkg/runtime/optimize/optimizer.go index 89bc9d42..0db5fcdf 100644 --- a/pkg/runtime/optimize/optimizer.go +++ b/pkg/runtime/optimize/optimizer.go @@ -19,1021 +19,114 @@ package optimize import ( "context" - stdErrors "errors" - "strings" + "errors" ) import ( "github.com/arana-db/parser/ast" - "github.com/pkg/errors" + perrors "github.com/pkg/errors" + + "go.opentelemetry.io/otel" ) import ( - "github.com/arana-db/arana/pkg/dataset" - "github.com/arana-db/arana/pkg/merge/aggregator" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/proto/rule" - "github.com/arana-db/arana/pkg/proto/schema_manager" - "github.com/arana-db/arana/pkg/runtime" rast "github.com/arana-db/arana/pkg/runtime/ast" - "github.com/arana-db/arana/pkg/runtime/cmp" rcontext "github.com/arana-db/arana/pkg/runtime/context" - "github.com/arana-db/arana/pkg/runtime/namespace" - "github.com/arana-db/arana/pkg/runtime/plan" - "github.com/arana-db/arana/pkg/security" - "github.com/arana-db/arana/pkg/transformer" "github.com/arana-db/arana/pkg/util/log" ) var _ proto.Optimizer = (*optimizer)(nil) +var Tracer = otel.Tracer("optimize") + // errors group var ( - errNoRuleFound = stdErrors.New("no rule found") - errDenyFullScan = stdErrors.New("the full-scan query is not allowed") - errNoShardKeyFound = stdErrors.New("no shard key found") + errNoRuleFound = errors.New("optimize: no rule found") + errDenyFullScan = errors.New("optimize: the full-scan query is not allowed") + errNoShardKeyFound = errors.New("optimize: no shard key found") ) // IsNoShardKeyFoundErr returns true if target error is caused by NO-SHARD-KEY-FOUND func IsNoShardKeyFoundErr(err error) bool { - return errors.Is(err, errNoShardKeyFound) + return perrors.Is(err, errNoShardKeyFound) } // IsNoRuleFoundErr returns true if target error is caused by NO-RULE-FOUND. func IsNoRuleFoundErr(err error) bool { - return errors.Is(err, errNoRuleFound) + return perrors.Is(err, errNoRuleFound) } // IsDenyFullScanErr returns true if target error is caused by DENY-FULL-SCAN. func IsDenyFullScanErr(err error) bool { - return errors.Is(err, errDenyFullScan) -} - -func GetOptimizer() proto.Optimizer { - return optimizer{ - schemaLoader: schema_manager.NewSimpleSchemaLoader(), - } + return perrors.Is(err, errDenyFullScan) } type optimizer struct { + rule *rule.Rule + hints []*hint.Hint + vconn proto.VConn + stmt rast.Statement + args []interface{} schemaLoader proto.SchemaLoader } -func (o *optimizer) SetSchemaLoader(schemaLoader proto.SchemaLoader) { - o.schemaLoader = schemaLoader -} - -func (o *optimizer) SchemaLoader() proto.SchemaLoader { - return o.schemaLoader -} - -func (o optimizer) Optimize(ctx context.Context, conn proto.VConn, stmt ast.StmtNode, args ...interface{}) (plan proto.Plan, err error) { - ctx, span := runtime.Tracer.Start(ctx, "Optimize") - defer func() { - span.End() - if rec := recover(); rec != nil { - err = errors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx)) - log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec) - } - }() - - var rstmt rast.Statement - if rstmt, err = rast.FromStmtNode(stmt); err != nil { - return nil, errors.Wrap(err, "optimize failed") - } - return o.doOptimize(ctx, conn, rstmt, args...) -} - -func (o optimizer) doOptimize(ctx context.Context, conn proto.VConn, stmt rast.Statement, args ...interface{}) (proto.Plan, error) { - switch t := stmt.(type) { - case *rast.ShowDatabases: - return o.optimizeShowDatabases(ctx, t, args) - case *rast.SelectStatement: - return o.optimizeSelect(ctx, conn, t, args) - case *rast.InsertStatement: - return o.optimizeInsert(ctx, conn, t, args) - case *rast.InsertSelectStatement: - return o.optimizeInsertSelect(ctx, conn, t, args) - case *rast.DeleteStatement: - return o.optimizeDelete(ctx, t, args) - case *rast.UpdateStatement: - return o.optimizeUpdate(ctx, conn, t, args) - case *rast.ShowOpenTables: - return o.optimizeShowOpenTables(ctx, t, args) - case *rast.ShowTables: - return o.optimizeShowTables(ctx, t, args) - case *rast.ShowIndex: - return o.optimizeShowIndex(ctx, t, args) - case *rast.ShowColumns: - return o.optimizeShowColumns(ctx, t, args) - case *rast.ShowCreate: - return o.optimizeShowCreate(ctx, t, args) - case *rast.TruncateStatement: - return o.optimizeTruncate(ctx, t, args) - case *rast.DropTableStatement: - return o.optimizeDropTable(ctx, t, args) - case *rast.ShowVariables: - return o.optimizeShowVariables(ctx, t, args) - case *rast.DescribeStatement: - return o.optimizeDescribeStatement(ctx, t, args) - case *rast.AlterTableStatement: - return o.optimizeAlterTable(ctx, t, args) - case *rast.DropIndexStatement: - return o.optimizeDropIndex(ctx, t, args) - case *rast.CreateIndexStatement: - return o.optimizeCreateIndex(ctx, t, args) - case *rast.DropTriggerStatement: - return o.optimizeTrigger(ctx, t, args) - } - - //TODO implement all statements - panic("implement me") -} - -const ( - _bypass uint32 = 1 << iota - _supported -) - -func (o optimizer) optimizeDropIndex(ctx context.Context, stmt *rast.DropIndexStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - //table shard - - shard, err := o.computeShards(ru, stmt.Table, nil, args) - if err != nil { - return nil, err - } - if len(shard) == 0 { - return plan.Transparent(stmt, args), nil - } - - shardPlan := plan.NewDropIndexPlan(stmt) - shardPlan.SetShard(shard) - shardPlan.BindArgs(args) - return shardPlan, nil -} - -func (o optimizer) optimizeCreateIndex(ctx context.Context, stmt *rast.CreateIndexStatement, args []interface{}) (proto.Plan, error) { - +func NewOptimizer(vconn proto.VConn, schemaer proto.SchemaLoader, rule *rule.Rule, hints []*hint.Hint, stmt ast.StmtNode, args []interface{}) (proto.Optimizer, error) { var ( - ret = plan.NewCreateIndexPlan(stmt) - ru = rcontext.Rule(ctx) - vt *rule.VTable + rstmt rast.Statement + err error ) - - //table shard - _, ok := ru.VTable(stmt.Table.String()) - if !ok { - return ret, nil - } - - // sharding - shards := rule.DatabaseTables{} - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - ret.SetShard(shards) - return ret, nil -} - -func (o optimizer) optimizeAlterTable(ctx context.Context, stmt *rast.AlterTableStatement, args []interface{}) (proto.Plan, error) { - var ( - ret = plan.NewAlterTablePlan(stmt) - ru = rcontext.Rule(ctx) - table = stmt.Table - vt *rule.VTable - ok bool - ) - ret.BindArgs(args) - - // non-sharding update - if vt, ok = ru.VTable(table.Suffix()); !ok { - return ret, nil - } - - //TODO alter table table or column to new name , should update sharding info - - // exit if full-scan is disabled - if !vt.AllowFullScan() { - return nil, errDenyFullScan - } - - // sharding - shards := rule.DatabaseTables{} - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - ret.Shards = shards - return ret, nil -} - -func (o optimizer) optimizeDropTable(ctx context.Context, stmt *rast.DropTableStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - //table shard - var shards []rule.DatabaseTables - //tables not shard - noShardStmt := rast.NewDropTableStatement() - for _, table := range stmt.Tables { - shard, err := o.computeShards(ru, *table, nil, args) - if err != nil { - return nil, err - } - if shard == nil { - noShardStmt.Tables = append(noShardStmt.Tables, table) - continue - } - shards = append(shards, shard) - } - - shardPlan := plan.NewDropTablePlan(stmt) - shardPlan.BindArgs(args) - shardPlan.SetShards(shards) - - if len(noShardStmt.Tables) == 0 { - return shardPlan, nil + if rstmt, err = rast.FromStmtNode(stmt); err != nil { + return nil, perrors.Wrap(err, "optimize failed") } - noShardPlan := plan.Transparent(noShardStmt, args) - - return &plan.UnionPlan{ - Plans: []proto.Plan{ - noShardPlan, shardPlan, - }, + return &optimizer{ + rule: rule, + hints: hints, + vconn: vconn, + stmt: rstmt, + args: args, + schemaLoader: schemaer, }, nil } -func (o optimizer) getSelectFlag(ctx context.Context, stmt *rast.SelectStatement) (flag uint32) { - switch len(stmt.From) { - case 1: - from := stmt.From[0] - tn := from.TableName() - - if tn == nil { // only FROM table supported now - return - } - - flag |= _supported - - if len(tn) > 1 { - switch strings.ToLower(tn.Prefix()) { - case "mysql", "information_schema": - flag |= _bypass - return - } - } - if !rcontext.Rule(ctx).Has(tn.Suffix()) { - flag |= _bypass - } - case 0: - flag |= _bypass - flag |= _supported - } - return -} - -func (o optimizer) optimizeShowDatabases(_ context.Context, stmt *rast.ShowDatabases, args []interface{}) (proto.Plan, error) { - ret := &plan.ShowDatabasesPlan{Stmt: stmt} - ret.BindArgs(args) - return ret, nil -} - -func (o optimizer) overwriteLimit(stmt *rast.SelectStatement, args *[]interface{}) (originOffset, overwriteLimit int64) { - if stmt == nil || stmt.Limit == nil { - return 0, 0 - } - - offset := stmt.Limit.Offset() - limit := stmt.Limit.Limit() - - // SELECT * FROM student where uid = ? limit ? offset ? - var offsetIndex int64 - var limitIndex int64 - - if stmt.Limit.IsOffsetVar() { - offsetIndex = offset - offset = (*args)[offsetIndex].(int64) - - if !stmt.Limit.IsLimitVar() { - limit = stmt.Limit.Limit() - *args = append(*args, limit) - limitIndex = int64(len(*args) - 1) - } - } - originOffset = offset - - if stmt.Limit.IsLimitVar() { - limitIndex = limit - limit = (*args)[limitIndex].(int64) - - if !stmt.Limit.IsOffsetVar() { - *args = append(*args, int64(0)) - offsetIndex = int64(len(*args) - 1) - } - } - - if stmt.Limit.IsLimitVar() || stmt.Limit.IsOffsetVar() { - if !stmt.Limit.IsLimitVar() { - stmt.Limit.SetLimitVar() - stmt.Limit.SetLimit(limitIndex) - } - if !stmt.Limit.IsOffsetVar() { - stmt.Limit.SetOffsetVar() - stmt.Limit.SetOffset(offsetIndex) - } - - newLimitVar := limit + offset - overwriteLimit = newLimitVar - (*args)[limitIndex] = newLimitVar - (*args)[offsetIndex] = int64(0) - return - } - - stmt.Limit.SetOffset(0) - stmt.Limit.SetLimit(offset + limit) - overwriteLimit = offset + limit - return -} - -func (o optimizer) optimizeOrderBy(stmt *rast.SelectStatement) []dataset.OrderByItem { - if stmt == nil || stmt.OrderBy == nil { - return nil - } - result := make([]dataset.OrderByItem, 0, len(stmt.OrderBy)) - for _, node := range stmt.OrderBy { - column, _ := node.Expr.(rast.ColumnNameExpressionAtom) - item := dataset.OrderByItem{ - Column: column[0], - Desc: node.Desc, - } - result = append(result, item) - } - return result -} - -func (o optimizer) optimizeSelect(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) { - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - - // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be - // `select * from student offset 0 limit 100+5` - originOffset, overwriteLimit := o.overwriteLimit(stmt, &args) - if stmt.HasJoin() { - return o.optimizeJoin(ctx, conn, stmt, args) - } - flag := o.getSelectFlag(ctx, stmt) - if flag&_supported == 0 { - return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) - } - - if flag&_bypass != 0 { - if len(stmt.From) > 0 { - err := o.rewriteSelectStatement(ctx, conn, stmt, rcontext.DBGroup(ctx), stmt.From[0].TableName().Suffix()) - if err != nil { - return nil, err - } - } - ret := &plan.SimpleQueryPlan{Stmt: stmt} - ret.BindArgs(args) - return ret, nil - } - - var ( - shards rule.DatabaseTables - fullScan bool - err error - vt = ru.MustVTable(stmt.From[0].TableName().Suffix()) - ) - - if shards, fullScan, err = (*Sharder)(ru).Shard(stmt.From[0].TableName(), stmt.Where, args...); err != nil { - return nil, errors.Wrap(err, "calculate shards failed") - } - - log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) - - // return error if full-scan is disabled - if fullScan && !vt.AllowFullScan() { - return nil, errors.WithStack(errDenyFullScan) - } - - toSingle := func(db, tbl string) (proto.Plan, error) { - if err := o.rewriteSelectStatement(ctx, conn, stmt, db, tbl); err != nil { - return nil, err - } - ret := &plan.SimpleQueryPlan{ - Stmt: stmt, - Database: db, - Tables: []string{tbl}, - } - ret.BindArgs(args) - - return ret, nil - } - - // Go through first table if no shards matched. - // For example: - // SELECT ... FROM xxx WHERE a > 8 and a < 4 - if shards.IsEmpty() { - var ( - db0, tbl0 string - ok bool - ) - if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok { - return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix()) - } - - return toSingle(db0, tbl0) - } - - // Handle single shard - if shards.Len() == 1 { - var db, tbl string - for k, v := range shards { - db = k - tbl = v[0] - } - return toSingle(db, tbl) - } - - // Handle multiple shards - - if shards.IsFullScan() { // expand all shards if all shards matched - // init shards - shards = rule.DatabaseTables{} - // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - } - - plans := make([]proto.Plan, 0, len(shards)) - for k, v := range shards { - next := &plan.SimpleQueryPlan{ - Database: k, - Tables: v, - Stmt: stmt, - } - next.BindArgs(args) - plans = append(plans, next) - } - - if len(plans) > 0 { - tempPlan := plans[0].(*plan.SimpleQueryPlan) - if err = o.rewriteSelectStatement(ctx, conn, stmt, tempPlan.Database, tempPlan.Tables[0]); err != nil { - return nil, err - } - } - - var tmpPlan proto.Plan - tmpPlan = &plan.UnionPlan{ - Plans: plans, - } - - if stmt.Limit != nil { - tmpPlan = &plan.LimitPlan{ - ParentPlan: tmpPlan, - OriginOffset: originOffset, - OverwriteLimit: overwriteLimit, - } - } - - orderByItems := o.optimizeOrderBy(stmt) - - if stmt.OrderBy != nil { - tmpPlan = &plan.OrderPlan{ - ParentPlan: tmpPlan, - OrderByItems: orderByItems, - } - } - - convertOrderByItems := func(origins []*rast.OrderByItem) []dataset.OrderByItem { - var result = make([]dataset.OrderByItem, 0, len(origins)) - for _, origin := range origins { - var columnName string - if cn, ok := origin.Expr.(rast.ColumnNameExpressionAtom); ok { - columnName = cn.Suffix() - } - result = append(result, dataset.OrderByItem{ - Column: columnName, - Desc: origin.Desc, - }) - } - return result - } - if stmt.GroupBy != nil { - return &plan.GroupPlan{ - Plan: tmpPlan, - AggItems: aggregator.LoadAggs(stmt.Select), - OrderByItems: convertOrderByItems(stmt.OrderBy), - }, nil - } else { - // TODO: refactor groupby/orderby/aggregate plan to a unified plan - return &plan.AggregatePlan{ - Plan: tmpPlan, - Combiner: transformer.NewCombinerManager(), - AggrLoader: transformer.LoadAggrs(stmt.Select), - }, nil - } -} - -//optimizeJoin ony support a join b in one db -func (o optimizer) optimizeJoin(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, args []interface{}) (proto.Plan, error) { - - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - - join := stmt.From[0].Source().(*rast.JoinNode) - - compute := func(tableSource *rast.TableSourceNode) (database, alias string, shardList []string, err error) { - table := tableSource.TableName() - if table == nil { - err = errors.New("must table, not statement or join node") - return - } - alias = tableSource.Alias() - database = table.Prefix() - - shards, err := o.computeShards(ru, table, nil, args) - if err != nil { - return - } - //table no shard - if shards == nil { - shardList = append(shardList, table.Suffix()) - return - } - //table shard more than one db - if len(shards) > 1 { - err = errors.New("not support more than one db") - return - } - - for k, v := range shards { - database = k - shardList = v - } - - if alias == "" { - alias = table.Suffix() - } - - return - } - - dbLeft, aliasLeft, shardLeft, err := compute(join.Left) - if err != nil { - return nil, err - } - dbRight, aliasRight, shardRight, err := compute(join.Right) - - if err != nil { - return nil, err - } - - if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - return nil, errors.New("not support more than one db") - } - - joinPan := &plan.SimpleJoinPlan{ - Left: &plan.JoinTable{ - Tables: shardLeft, - Alias: aliasLeft, - }, - Join: join, - Right: &plan.JoinTable{ - Tables: shardRight, - Alias: aliasRight, - }, - Stmt: stmt, - } - joinPan.BindArgs(args) - - return joinPan, nil -} - -func (o optimizer) optimizeUpdate(ctx context.Context, conn proto.VConn, stmt *rast.UpdateStatement, args []interface{}) (proto.Plan, error) { - var ( - ru = rcontext.Rule(ctx) - table = stmt.Table - vt *rule.VTable - ok bool - ) - - // non-sharding update - if vt, ok = ru.VTable(table.Suffix()); !ok { - ret := plan.NewUpdatePlan(stmt) - ret.BindArgs(args) - return ret, nil - } - - //check update sharding key - for _, element := range stmt.Updated { - if _, _, ok := vt.GetShardMetadata(element.Column.Suffix()); ok { - return nil, errors.New("do not support update sharding key") - } - } +type optimizeHandler func(ctx context.Context, o *optimizer) (proto.Plan, error) - var ( - shards rule.DatabaseTables - fullScan = true - err error - ) - - // compute shards - if where := stmt.Where; where != nil { - sharder := (*Sharder)(ru) - if shards, fullScan, err = sharder.Shard(table, where, args...); err != nil { - return nil, errors.Wrap(err, "failed to update") - } - } - - // exit if full-scan is disabled - if fullScan && !vt.AllowFullScan() { - return nil, errDenyFullScan - } - - // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1) - if shards.IsEmpty() { - return plan.AlwaysEmptyExecPlan{}, nil - } - - // compute all sharding tables - if shards.IsFullScan() { - // init shards - shards = rule.DatabaseTables{} - // compute all tables - topology := vt.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - } - - ret := plan.NewUpdatePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil -} - -func (o optimizer) optimizeInsert(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement, args []interface{}) (proto.Plan, error) { - var ( - ru = rcontext.Rule(ctx) - ret = plan.NewSimpleInsertPlan() - ) - - ret.BindArgs(args) - - var ( - vt *rule.VTable - ok bool - ) - - if vt, ok = ru.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table - ret.Put("", stmt) - return ret, nil - } - - // TODO: handle multiple shard keys. - - bingo := -1 - // check existing shard columns - for i, col := range stmt.Columns() { - if _, _, ok = vt.GetShardMetadata(col); ok { - bingo = i - break - } - } - - if bingo < 0 { - return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") - } - - //check on duplicated key update - for _, upd := range stmt.DuplicatedUpdates() { - if upd.Column.Suffix() == stmt.Columns()[bingo] { - return nil, errors.New("do not support update sharding key") - } - } - - var ( - sharder = (*Sharder)(ru) - left = rast.ColumnNameExpressionAtom(make([]string, 1)) - filter = &rast.PredicateExpressionNode{ - P: &rast.BinaryComparisonPredicateNode{ - Left: &rast.AtomPredicateNode{ - A: left, - }, - Op: cmp.Ceq, - }, - } - slots = make(map[string]map[string][]int) // (db,table,valuesIndex) - ) - - // reset filter - resetFilter := func(column string, value rast.ExpressionNode) { - left[0] = column - filter.P.(*rast.BinaryComparisonPredicateNode).Right = value.(*rast.PredicateExpressionNode).P - } - - for i, values := range stmt.Values() { - value := values[bingo] - resetFilter(stmt.Columns()[bingo], value) - - shards, _, err := sharder.Shard(stmt.Table(), filter, args...) - - if err != nil { - return nil, errors.WithStack(err) - } - - if shards.Len() != 1 { - return nil, errors.Wrap(errNoShardKeyFound, "failed to insert") - } - - var ( - db string - table string - ) - - for k, v := range shards { - db = k - table = v[0] - break - } - - if _, ok = slots[db]; !ok { - slots[db] = make(map[string][]int) - } - slots[db][table] = append(slots[db][table], i) - } - - for db, slot := range slots { - for table, indexes := range slot { - // clone insert stmt without values - newborn := rast.NewInsertStatement(rast.TableName{table}, stmt.Columns()) - newborn.SetFlag(stmt.Flag()) - newborn.SetDuplicatedUpdates(stmt.DuplicatedUpdates()) - - // collect values with same table - values := make([][]rast.ExpressionNode, 0, len(indexes)) - for _, i := range indexes { - values = append(values, stmt.Values()[i]) - } - newborn.SetValues(values) - - o.rewriteInsertStatement(ctx, conn, newborn, db, table) - ret.Put(db, newborn) - } - } - - return ret, nil -} - -func (o optimizer) optimizeInsertSelect(ctx context.Context, conn proto.VConn, stmt *rast.InsertSelectStatement, args []interface{}) (proto.Plan, error) { - var ( - ru = rcontext.Rule(ctx) - ret = plan.NewInsertSelectPlan() - ) - - ret.BindArgs(args) - - if _, ok := ru.VTable(stmt.Table().Suffix()); !ok { // insert into non-sharding table - ret.Batch[""] = stmt - return ret, nil - } - - // TODO: handle shard keys. +var ( + _handlers = make(map[rast.SQLType]optimizeHandler) +) - return nil, errors.New("not support insert-select into sharding table") +func registerOptimizeHandler(t rast.SQLType, h optimizeHandler) { + _handlers[t] = h } -func (o optimizer) optimizeDelete(ctx context.Context, stmt *rast.DeleteStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - shards, err := o.computeShards(ru, stmt.Table, stmt.Where, args) - if err != nil { - return nil, errors.Wrap(err, "failed to optimize DELETE statement") - } - - // TODO: delete from a child sharding-table directly - - if shards == nil { - return plan.Transparent(stmt, args), nil - } - - ret := plan.NewSimpleDeletePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil +func init() { + registerOptimizeHandler(rast.SQLTypeAlterTable, optimizeAlterTable) } -func (o optimizer) optimizeShowOpenTables(ctx context.Context, stmt *rast.ShowOpenTables, args []interface{}) (proto.Plan, error) { - var invertedIndex map[string]string - for logicalTable, v := range rcontext.Rule(ctx).VTables() { - t := v.Topology() - t.Each(func(x, y int) bool { - if _, phyTable, ok := t.Render(x, y); ok { - if invertedIndex == nil { - invertedIndex = make(map[string]string) - } - invertedIndex[phyTable] = logicalTable - } - return true - }) - } - - clusters := security.DefaultTenantManager().GetClusters(rcontext.Tenant(ctx)) - plans := make([]proto.Plan, 0, len(clusters)) - for _, cluster := range clusters { - ns := namespace.Load(cluster) - // 配置里原子库 都需要执行一次 - groups := ns.DBGroups() - for i := 0; i < len(groups); i++ { - ret := plan.NewShowOpenTablesPlan(stmt) - ret.BindArgs(args) - ret.SetInvertedShards(invertedIndex) - ret.SetDatabase(groups[i]) - plans = append(plans, ret) +func (o *optimizer) Optimize(ctx context.Context) (plan proto.Plan, err error) { + ctx, span := Tracer.Start(ctx, "Optimize") + defer func() { + span.End() + if rec := recover(); rec != nil { + err = perrors.Errorf("cannot analyze sql %s", rcontext.SQL(ctx)) + log.Errorf("optimize panic: sql=%s, rec=%v", rcontext.SQL(ctx), rec) } - } - - unionPlan := &plan.UnionPlan{ - Plans: plans, - } - - aggregate := &plan.AggregatePlan{ - Plan: unionPlan, - Combiner: transformer.NewCombinerManager(), - AggrLoader: transformer.LoadAggrs(nil), - } - - return aggregate, nil -} - -func (o optimizer) optimizeShowTables(ctx context.Context, stmt *rast.ShowTables, args []interface{}) (proto.Plan, error) { - var invertedIndex map[string]string - for logicalTable, v := range rcontext.Rule(ctx).VTables() { - t := v.Topology() - t.Each(func(x, y int) bool { - if _, phyTable, ok := t.Render(x, y); ok { - if invertedIndex == nil { - invertedIndex = make(map[string]string) - } - invertedIndex[phyTable] = logicalTable - } - return true - }) - } - - ret := plan.NewShowTablesPlan(stmt) - ret.BindArgs(args) - ret.SetInvertedShards(invertedIndex) - return ret, nil -} - -func (o optimizer) optimizeShowIndex(ctx context.Context, stmt *rast.ShowIndex, args []interface{}) (proto.Plan, error) { - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - - ret := &plan.ShowIndexPlan{Stmt: stmt} - ret.BindArgs(args) + }() - vt, ok := ru.VTable(stmt.TableName.Suffix()) + h, ok := _handlers[o.stmt.Mode()] if !ok { - return ret, nil - } - - shards := rule.DatabaseTables{} - - topology := vt.Topology() - if d, t, ok := topology.Render(0, 0); ok { - shards[d] = append(shards[d], t) - } - ret.Shards = shards - return ret, nil -} - -func (o optimizer) optimizeShowColumns(ctx context.Context, stmt *rast.ShowColumns, args []interface{}) (proto.Plan, error) { - vts := rcontext.Rule(ctx).VTables() - vtName := []string(stmt.TableName)[0] - ret := &plan.ShowColumnsPlan{Stmt: stmt} - ret.BindArgs(args) - - if vTable, ok := vts[vtName]; ok { - shards := rule.DatabaseTables{} - // compute all tables - topology := vTable.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - _, tblName := shards.Smallest() - ret.Table = tblName - } - - return ret, nil -} - -func (o optimizer) optimizeShowCreate(ctx context.Context, stmt *rast.ShowCreate, args []interface{}) (proto.Plan, error) { - if stmt.Type() != rast.ShowCreateTypeTable { - return nil, errors.Errorf("not support SHOW CREATE %s", stmt.Type().String()) - } - - var ( - ret = plan.NewShowCreatePlan(stmt) - ru = rcontext.Rule(ctx) - table = stmt.Target() - ) - ret.BindArgs(args) - - if vt, ok := ru.VTable(table); ok { - // sharding - topology := vt.Topology() - if d, t, ok := topology.Render(0, 0); ok { - ret.Database = d - ret.Table = t - } else { - return nil, errors.Errorf("failed to render table:%s ", table) - } - } else { - ret.Table = table - } - - return ret, nil -} - -func (o optimizer) optimizeTruncate(ctx context.Context, stmt *rast.TruncateStatement, args []interface{}) (proto.Plan, error) { - ru := rcontext.Rule(ctx) - shards, err := o.computeShards(ru, stmt.Table, nil, args) - if err != nil { - return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement") - } - - if shards == nil { - return plan.Transparent(stmt, args), nil - } - - ret := plan.NewTruncatePlan(stmt) - ret.BindArgs(args) - ret.SetShards(shards) - - return ret, nil -} - -func (o optimizer) optimizeShowVariables(ctx context.Context, stmt *rast.ShowVariables, args []interface{}) (proto.Plan, error) { - ret := plan.NewShowVariablesPlan(stmt) - ret.BindArgs(args) - return ret, nil -} - -func (o optimizer) optimizeDescribeStatement(ctx context.Context, stmt *rast.DescribeStatement, args []interface{}) (proto.Plan, error) { - vts := rcontext.Rule(ctx).VTables() - vtName := []string(stmt.Table)[0] - ret := plan.NewDescribePlan(stmt) - ret.BindArgs(args) - - if vTable, ok := vts[vtName]; ok { - shards := rule.DatabaseTables{} - // compute all tables - topology := vTable.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - dbName, tblName := shards.Smallest() - ret.Database = dbName - ret.Table = tblName - ret.Column = stmt.Column + return nil, perrors.Errorf("optimize: no handler found for '%s'", o.stmt.Mode()) } - return ret, nil + return h(ctx, o) } -func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) { +func (o optimizer) computeShards(table rast.TableName, where rast.ExpressionNode, args []interface{}) (rule.DatabaseTables, error) { + ru := o.rule vt, ok := ru.VTable(table.Suffix()) if !ok { return nil, nil @@ -1041,14 +134,14 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast shards, fullScan, err := (*Sharder)(ru).Shard(table, where, args...) if err != nil { - return nil, errors.Wrap(err, "calculate shards failed") + return nil, perrors.Wrap(err, "calculate shards failed") } log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) // return error if full-scan is disabled if fullScan && !vt.AllowFullScan() { - return nil, errors.WithStack(errDenyFullScan) + return nil, perrors.WithStack(errDenyFullScan) } if shards.IsEmpty() { @@ -1070,97 +163,3 @@ func (o optimizer) computeShards(ru *rule.Rule, table rast.TableName, where rast return shards, nil } - -func (o optimizer) rewriteSelectStatement(ctx context.Context, conn proto.VConn, stmt *rast.SelectStatement, - db, tb string) error { - // todo db 计算逻辑&tb shard 的计算逻辑 - var starExpand = false - if len(stmt.Select) == 1 { - if _, ok := stmt.Select[0].(*rast.SelectElementAll); ok { - starExpand = true - } - } - - if starExpand { - if len(tb) < 1 { - tb = stmt.From[0].TableName().Suffix() - } - metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb] - if metaData == nil || len(metaData.ColumnNames) == 0 { - return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) - } - selectElements := make([]rast.SelectElement, len(metaData.Columns)) - for i, column := range metaData.ColumnNames { - selectElements[i] = rast.NewSelectElementColumn([]string{column}, "") - } - stmt.Select = selectElements - } - - return nil -} - -func (o optimizer) rewriteInsertStatement(ctx context.Context, conn proto.VConn, stmt *rast.InsertStatement, - db, tb string) error { - metaData := o.schemaLoader.Load(ctx, conn, db, []string{tb})[tb] - if metaData == nil || len(metaData.ColumnNames) == 0 { - return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) - } - - if len(metaData.ColumnNames) == len(stmt.Columns()) { - // User had explicitly specified every value - return nil - } - columnsMetadata := metaData.Columns - - for _, colName := range stmt.Columns() { - if columnsMetadata[colName].PrimaryKey && columnsMetadata[colName].Generated { - // User had explicitly specified auto-generated primary key column - return nil - } - } - - pkColName := "" - for name, column := range columnsMetadata { - if column.PrimaryKey && column.Generated { - pkColName = name - break - } - } - if len(pkColName) < 1 { - // There's no auto-generated primary key column - return nil - } - - // TODO rewrite columns and add distributed primary key - //stmt.SetColumns(append(stmt.Columns(), pkColName)) - // append value of distributed primary key - //newValues := stmt.Values() - //for _, newValue := range newValues { - // newValue = append(newValue, ) - //} - return nil -} - -func (o optimizer) optimizeTrigger(ctx context.Context, stmt *rast.DropTriggerStatement, args []interface{}) (proto.Plan, error) { - var ru *rule.Rule - if ru = rcontext.Rule(ctx); ru == nil { - return nil, errors.WithStack(errNoRuleFound) - } - - shards := rule.DatabaseTables{} - for _, table := range ru.VTables() { - topology := table.Topology() - topology.Each(func(dbIdx, tbIdx int) bool { - if d, t, ok := topology.Render(dbIdx, tbIdx); ok { - shards[d] = append(shards[d], t) - } - return true - }) - - break - } - - ret := &plan.DropTriggerPlan{Stmt: stmt, Shards: shards} - ret.BindArgs(args) - return ret, nil -} diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index 703af0d2..580cd224 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -36,7 +36,6 @@ import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/resultx" - rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/testdata" ) @@ -58,16 +57,16 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { AnyTimes() var ( - sql = "select id, uid from student where uid in (?,?,?)" - ctx = context.Background() - rule = makeFakeRule(ctrl, 8) - opt optimizer + sql = "select id, uid from student where uid in (?,?,?)" + ctx = context.Background() + ru = makeFakeRule(ctrl, 8) ) p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1, 2, 3) + opt, err := NewOptimizer(conn, nil, ru, nil, stmt, []interface{}{1, 2, 3}) + assert.NoError(t, err) + plan, err := opt.Optimize(ctx) assert.NoError(t, err) _, _ = plan.ExecIn(ctx, conn) @@ -130,18 +129,21 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeStudentMetadata).Times(2) var ( - ctx = context.Background() - rule = makeFakeRule(ctrl, 8) - opt = optimizer{schemaLoader: loader} + ctx = context.Background() + ru = makeFakeRule(ctrl, 8) ) t.Run("sharding", func(t *testing.T) { + sql := "insert into student(name,uid,age) values('foo',?,18),('bar',?,19),('qux',?,17)" p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 8, 9, 16) // 8,16 -> fake_db_0000, 9 -> fake_db_0001 + opt, err := NewOptimizer(conn, loader, ru, nil, stmt, []interface{}{8, 9, 16}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) // 8,16 -> fake_db_0000, 9 -> fake_db_0001 assert.NoError(t, err) res, err := plan.ExecIn(ctx, conn) @@ -159,7 +161,10 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, rule), conn, stmt, 1) + opt, err := NewOptimizer(conn, loader, ru, nil, stmt, []interface{}{1}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) res, err := plan.ExecIn(ctx, conn) @@ -187,14 +192,13 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { }).AnyTimes() var ( - ctx = context.Background() - opt = optimizer{schemaLoader: loader} - ru rule.Rule - tab rule.VTable - topo rule.Topology + ctx = context.Background() + ru rule.Rule + tab rule.VTable + topology rule.Topology ) - topo.SetRender(func(_ int) string { + topology.SetRender(func(_ int) string { return "fake_db" }, func(i int) string { return fmt.Sprintf("student_%04d", i) @@ -203,8 +207,8 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { for i := 0; i < 8; i++ { tables = append(tables, i) } - topo.SetTopology(0, tables...) - tab.SetTopology(&topo) + topology.SetTopology(0, tables...) + tab.SetTopology(&topology) tab.SetAllowFullScan(true) ru.SetVTable("student", &tab) @@ -214,7 +218,10 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + opt, err := NewOptimizer(conn, loader, &ru, nil, stmt, nil) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) _, err = plan.ExecIn(ctx, conn) @@ -228,7 +235,10 @@ func TestOptimizer_OptimizeAlterTable(t *testing.T) { p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt) + opt, err := NewOptimizer(conn, loader, &ru, nil, stmt, nil) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) _, err = plan.ExecIn(ctx, conn) @@ -259,17 +269,20 @@ func TestOptimizer_OptimizeInsertSelect(t *testing.T) { var ( ctx = context.Background() ru rule.Rule - opt = optimizer{schemaLoader: loader} ) + ru.SetVTable("student", nil) + t.Run("non-sharding", func(t *testing.T) { sql := "insert into employees(name, age) select name,age from employees_tmp limit 10,2" p := parser.New() stmt, _ := p.ParseOneStmt(sql, "", "") - ru.SetVTable("student", nil) - plan, err := opt.Optimize(rcontext.WithRule(ctx, &ru), conn, stmt, 1) + opt, err := NewOptimizer(conn, loader, &ru, nil, stmt, []interface{}{1}) + assert.NoError(t, err) + + plan, err := opt.Optimize(ctx) assert.NoError(t, err) res, err := plan.ExecIn(ctx, conn) diff --git a/pkg/runtime/optimize/select.go b/pkg/runtime/optimize/select.go new file mode 100644 index 00000000..a2421ccc --- /dev/null +++ b/pkg/runtime/optimize/select.go @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" + "strings" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/merge/aggregator" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/transformer" + "github.com/arana-db/arana/pkg/util/log" +) + +const ( + _bypass uint32 = 1 << iota + _supported +) + +func init() { + registerOptimizeHandler(ast.SQLTypeSelect, optimizeSelect) +} + +func optimizeSelect(ctx context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.SelectStatement) + + // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be + // `select * from student offset 0 limit 100+5` + originOffset, newLimit := overwriteLimit(stmt, &o.args) + if stmt.HasJoin() { + return optimizeJoin(o, stmt) + } + flag := getSelectFlag(o.rule, stmt) + if flag&_supported == 0 { + return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) + } + + if flag&_bypass != 0 { + if len(stmt.From) > 0 { + err := rewriteSelectStatement(ctx, o, stmt, rcontext.DBGroup(ctx), stmt.From[0].TableName().Suffix()) + if err != nil { + return nil, err + } + } + ret := &plan.SimpleQueryPlan{Stmt: stmt} + ret.BindArgs(o.args) + return ret, nil + } + + var ( + shards rule.DatabaseTables + fullScan bool + err error + vt = o.rule.MustVTable(stmt.From[0].TableName().Suffix()) + ) + + if shards, fullScan, err = (*Sharder)(o.rule).Shard(stmt.From[0].TableName(), stmt.Where, o.args...); err != nil { + return nil, errors.Wrap(err, "calculate shards failed") + } + + log.Debugf("compute shards: result=%s, isFullScan=%v", shards, fullScan) + + // return error if full-scan is disabled + if fullScan && !vt.AllowFullScan() { + return nil, errors.WithStack(errDenyFullScan) + } + + toSingle := func(db, tbl string) (proto.Plan, error) { + if err := rewriteSelectStatement(ctx, o, stmt, db, tbl); err != nil { + return nil, err + } + ret := &plan.SimpleQueryPlan{ + Stmt: stmt, + Database: db, + Tables: []string{tbl}, + } + ret.BindArgs(o.args) + + return ret, nil + } + + // Go through first table if no shards matched. + // For example: + // SELECT ... FROM xxx WHERE a > 8 and a < 4 + if shards.IsEmpty() { + var ( + db0, tbl0 string + ok bool + ) + if db0, tbl0, ok = vt.Topology().Render(0, 0); !ok { + return nil, errors.Errorf("cannot compute minimal topology from '%s'", stmt.From[0].TableName().Suffix()) + } + + return toSingle(db0, tbl0) + } + + // Handle single shard + if shards.Len() == 1 { + var db, tbl string + for k, v := range shards { + db = k + tbl = v[0] + } + return toSingle(db, tbl) + } + + // Handle multiple shards + + if shards.IsFullScan() { // expand all shards if all shards matched + // init shards + shards = rule.DatabaseTables{} + // compute all tables + topology := vt.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + } + + plans := make([]proto.Plan, 0, len(shards)) + for k, v := range shards { + next := &plan.SimpleQueryPlan{ + Database: k, + Tables: v, + Stmt: stmt, + } + next.BindArgs(o.args) + plans = append(plans, next) + } + + if len(plans) > 0 { + tempPlan := plans[0].(*plan.SimpleQueryPlan) + if err = rewriteSelectStatement(ctx, o, stmt, tempPlan.Database, tempPlan.Tables[0]); err != nil { + return nil, err + } + } + + var tmpPlan proto.Plan + tmpPlan = &plan.UnionPlan{ + Plans: plans, + } + + if stmt.Limit != nil { + tmpPlan = &plan.LimitPlan{ + ParentPlan: tmpPlan, + OriginOffset: originOffset, + OverwriteLimit: newLimit, + } + } + + orderByItems := optimizeOrderBy(stmt) + + if stmt.OrderBy != nil { + tmpPlan = &plan.OrderPlan{ + ParentPlan: tmpPlan, + OrderByItems: orderByItems, + } + } + + convertOrderByItems := func(origins []*ast.OrderByItem) []dataset.OrderByItem { + var result = make([]dataset.OrderByItem, 0, len(origins)) + for _, origin := range origins { + var columnName string + if cn, ok := origin.Expr.(ast.ColumnNameExpressionAtom); ok { + columnName = cn.Suffix() + } + result = append(result, dataset.OrderByItem{ + Column: columnName, + Desc: origin.Desc, + }) + } + return result + } + if stmt.GroupBy != nil { + return &plan.GroupPlan{ + Plan: tmpPlan, + AggItems: aggregator.LoadAggs(stmt.Select), + OrderByItems: convertOrderByItems(stmt.OrderBy), + }, nil + } else { + // TODO: refactor groupby/orderby/aggregate plan to a unified plan + return &plan.AggregatePlan{ + Plan: tmpPlan, + Combiner: transformer.NewCombinerManager(), + AggrLoader: transformer.LoadAggrs(stmt.Select), + }, nil + } +} + +//optimizeJoin ony support a join b in one db +func optimizeJoin(o *optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { + join := stmt.From[0].Source().(*ast.JoinNode) + + compute := func(tableSource *ast.TableSourceNode) (database, alias string, shardList []string, err error) { + table := tableSource.TableName() + if table == nil { + err = errors.New("must table, not statement or join node") + return + } + alias = tableSource.Alias() + database = table.Prefix() + + shards, err := o.computeShards(table, nil, o.args) + if err != nil { + return + } + //table no shard + if shards == nil { + shardList = append(shardList, table.Suffix()) + return + } + //table shard more than one db + if len(shards) > 1 { + err = errors.New("not support more than one db") + return + } + + for k, v := range shards { + database = k + shardList = v + } + + if alias == "" { + alias = table.Suffix() + } + + return + } + + dbLeft, aliasLeft, shardLeft, err := compute(join.Left) + if err != nil { + return nil, err + } + dbRight, aliasRight, shardRight, err := compute(join.Right) + + if err != nil { + return nil, err + } + + if dbLeft != "" && dbRight != "" && dbLeft != dbRight { + return nil, errors.New("not support more than one db") + } + + joinPan := &plan.SimpleJoinPlan{ + Left: &plan.JoinTable{ + Tables: shardLeft, + Alias: aliasLeft, + }, + Join: join, + Right: &plan.JoinTable{ + Tables: shardRight, + Alias: aliasRight, + }, + Stmt: o.stmt.(*ast.SelectStatement), + } + joinPan.BindArgs(o.args) + + return joinPan, nil +} + +func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) { + switch len(stmt.From) { + case 1: + from := stmt.From[0] + tn := from.TableName() + + if tn == nil { // only FROM table supported now + return + } + + flag |= _supported + + if len(tn) > 1 { + switch strings.ToLower(tn.Prefix()) { + case "mysql", "information_schema": + flag |= _bypass + return + } + } + if !ru.Has(tn.Suffix()) { + flag |= _bypass + } + case 0: + flag |= _bypass + flag |= _supported + } + return +} + +func optimizeOrderBy(stmt *ast.SelectStatement) []dataset.OrderByItem { + if stmt == nil || stmt.OrderBy == nil { + return nil + } + result := make([]dataset.OrderByItem, 0, len(stmt.OrderBy)) + for _, node := range stmt.OrderBy { + column, _ := node.Expr.(ast.ColumnNameExpressionAtom) + item := dataset.OrderByItem{ + Column: column[0], + Desc: node.Desc, + } + result = append(result, item) + } + return result +} + +func overwriteLimit(stmt *ast.SelectStatement, args *[]interface{}) (originOffset, overwriteLimit int64) { + if stmt == nil || stmt.Limit == nil { + return 0, 0 + } + + offset := stmt.Limit.Offset() + limit := stmt.Limit.Limit() + + // SELECT * FROM student where uid = ? limit ? offset ? + var offsetIndex int64 + var limitIndex int64 + + if stmt.Limit.IsOffsetVar() { + offsetIndex = offset + offset = (*args)[offsetIndex].(int64) + + if !stmt.Limit.IsLimitVar() { + limit = stmt.Limit.Limit() + *args = append(*args, limit) + limitIndex = int64(len(*args) - 1) + } + } + originOffset = offset + + if stmt.Limit.IsLimitVar() { + limitIndex = limit + limit = (*args)[limitIndex].(int64) + + if !stmt.Limit.IsOffsetVar() { + *args = append(*args, int64(0)) + offsetIndex = int64(len(*args) - 1) + } + } + + if stmt.Limit.IsLimitVar() || stmt.Limit.IsOffsetVar() { + if !stmt.Limit.IsLimitVar() { + stmt.Limit.SetLimitVar() + stmt.Limit.SetLimit(limitIndex) + } + if !stmt.Limit.IsOffsetVar() { + stmt.Limit.SetOffsetVar() + stmt.Limit.SetOffset(offsetIndex) + } + + newLimitVar := limit + offset + overwriteLimit = newLimitVar + (*args)[limitIndex] = newLimitVar + (*args)[offsetIndex] = int64(0) + return + } + + stmt.Limit.SetOffset(0) + stmt.Limit.SetLimit(offset + limit) + overwriteLimit = offset + limit + return +} + +func rewriteSelectStatement(ctx context.Context, o *optimizer, stmt *ast.SelectStatement, db, tb string) error { + // todo db 计算逻辑&tb shard 的计算逻辑 + var starExpand = false + if len(stmt.Select) == 1 { + if _, ok := stmt.Select[0].(*ast.SelectElementAll); ok { + starExpand = true + } + } + + if starExpand { + if len(tb) < 1 { + tb = stmt.From[0].TableName().Suffix() + } + metaData := o.schemaLoader.Load(ctx, o.vconn, db, []string{tb})[tb] + if metaData == nil || len(metaData.ColumnNames) == 0 { + return errors.Errorf("can not get metadata for db:%s and table:%s", db, tb) + } + selectElements := make([]ast.SelectElement, len(metaData.Columns)) + for i, column := range metaData.ColumnNames { + selectElements[i] = ast.NewSelectElementColumn([]string{column}, "") + } + stmt.Select = selectElements + } + + return nil +} diff --git a/pkg/runtime/optimize/sharder_test.go b/pkg/runtime/optimize/sharder_test.go index 4a6b6ac8..de0bcbc7 100644 --- a/pkg/runtime/optimize/sharder_test.go +++ b/pkg/runtime/optimize/sharder_test.go @@ -56,7 +56,8 @@ func TestShard(t *testing.T) { {"select * from student where uid = if(PI()<3, 1, ?)", []interface{}{0}, []int{0}}, } { t.Run(it.sql, func(t *testing.T) { - stmt := ast.MustParse(it.sql).(*ast.SelectStatement) + _, rawStmt := ast.MustParse(it.sql) + stmt := rawStmt.(*ast.SelectStatement) result, _, err := (*Sharder)(fakeRule).Shard(stmt.From[0].TableName(), stmt.Where, it.args...) assert.NoError(t, err, "shard failed") diff --git a/pkg/runtime/optimize/show_columns.go b/pkg/runtime/optimize/show_columns.go new file mode 100644 index 00000000..542a259a --- /dev/null +++ b/pkg/runtime/optimize/show_columns.go @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowColumns, optimizeShowColumns) +} + +func optimizeShowColumns(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.ShowColumns) + + vts := o.rule.VTables() + vtName := []string(stmt.TableName)[0] + ret := &plan.ShowColumnsPlan{Stmt: stmt} + ret.BindArgs(o.args) + + if vTable, ok := vts[vtName]; ok { + shards := rule.DatabaseTables{} + // compute all tables + topology := vTable.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + _, tblName := shards.Smallest() + ret.Table = tblName + } + + return ret, nil +} diff --git a/pkg/runtime/optimize/show_create.go b/pkg/runtime/optimize/show_create.go new file mode 100644 index 00000000..c8260218 --- /dev/null +++ b/pkg/runtime/optimize/show_create.go @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowCreate, optimizeShowCreate) +} + +func optimizeShowCreate(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.ShowCreate) + + if stmt.Type() != ast.ShowCreateTypeTable { + return nil, errors.Errorf("not support SHOW CREATE %s", stmt.Type().String()) + } + + var ( + ret = plan.NewShowCreatePlan(stmt) + table = stmt.Target() + ) + ret.BindArgs(o.args) + + if vt, ok := o.rule.VTable(table); ok { + // sharding + topology := vt.Topology() + if d, t, ok := topology.Render(0, 0); ok { + ret.Database = d + ret.Table = t + } else { + return nil, errors.Errorf("failed to render table:%s ", table) + } + } else { + ret.Table = table + } + + return ret, nil +} diff --git a/pkg/runtime/optimize/show_databases.go b/pkg/runtime/optimize/show_databases.go new file mode 100644 index 00000000..b6ae5037 --- /dev/null +++ b/pkg/runtime/optimize/show_databases.go @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowDatabases, optimizeShowDatabases) +} + +func optimizeShowDatabases(_ context.Context, o *optimizer) (proto.Plan, error) { + ret := &plan.ShowDatabasesPlan{Stmt: o.stmt.(*ast.ShowDatabases)} + ret.BindArgs(o.args) + return ret, nil +} diff --git a/pkg/runtime/optimize/show_index.go b/pkg/runtime/optimize/show_index.go new file mode 100644 index 00000000..455ce23a --- /dev/null +++ b/pkg/runtime/optimize/show_index.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowIndex, optimizeShowIndex) +} + +func optimizeShowIndex(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.ShowIndex) + + ret := &plan.ShowIndexPlan{Stmt: stmt} + ret.BindArgs(o.args) + + vt, ok := o.rule.VTable(stmt.TableName.Suffix()) + if !ok { + return ret, nil + } + + shards := rule.DatabaseTables{} + + topology := vt.Topology() + if d, t, ok := topology.Render(0, 0); ok { + shards[d] = append(shards[d], t) + } + ret.Shards = shards + return ret, nil +} diff --git a/pkg/runtime/optimize/show_open_tables.go b/pkg/runtime/optimize/show_open_tables.go new file mode 100644 index 00000000..98ebfaa6 --- /dev/null +++ b/pkg/runtime/optimize/show_open_tables.go @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + rcontext "github.com/arana-db/arana/pkg/runtime/context" + "github.com/arana-db/arana/pkg/runtime/namespace" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/pkg/security" + "github.com/arana-db/arana/pkg/transformer" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowOpenTables, optimizeShowOpenTables) +} + +func optimizeShowOpenTables(ctx context.Context, o *optimizer) (proto.Plan, error) { + var invertedIndex map[string]string + for logicalTable, v := range o.rule.VTables() { + t := v.Topology() + t.Each(func(x, y int) bool { + if _, phyTable, ok := t.Render(x, y); ok { + if invertedIndex == nil { + invertedIndex = make(map[string]string) + } + invertedIndex[phyTable] = logicalTable + } + return true + }) + } + + stmt := o.stmt.(*ast.ShowOpenTables) + + clusters := security.DefaultTenantManager().GetClusters(rcontext.Tenant(ctx)) + plans := make([]proto.Plan, 0, len(clusters)) + for _, cluster := range clusters { + ns := namespace.Load(cluster) + // 配置里原子库 都需要执行一次 + groups := ns.DBGroups() + for i := 0; i < len(groups); i++ { + ret := plan.NewShowOpenTablesPlan(stmt) + ret.BindArgs(o.args) + ret.SetInvertedShards(invertedIndex) + ret.SetDatabase(groups[i]) + plans = append(plans, ret) + } + } + + unionPlan := &plan.UnionPlan{ + Plans: plans, + } + + aggregate := &plan.AggregatePlan{ + Plan: unionPlan, + Combiner: transformer.NewCombinerManager(), + AggrLoader: transformer.LoadAggrs(nil), + } + + return aggregate, nil +} diff --git a/pkg/runtime/optimize/show_tables.go b/pkg/runtime/optimize/show_tables.go new file mode 100644 index 00000000..e7f31b63 --- /dev/null +++ b/pkg/runtime/optimize/show_tables.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowTables, optimizeShowTables) +} + +func optimizeShowTables(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.ShowTables) + var invertedIndex map[string]string + for logicalTable, v := range o.rule.VTables() { + t := v.Topology() + t.Each(func(x, y int) bool { + if _, phyTable, ok := t.Render(x, y); ok { + if invertedIndex == nil { + invertedIndex = make(map[string]string) + } + invertedIndex[phyTable] = logicalTable + } + return true + }) + } + + ret := plan.NewShowTablesPlan(stmt) + ret.BindArgs(o.args) + ret.SetInvertedShards(invertedIndex) + return ret, nil +} diff --git a/pkg/runtime/optimize/show_variables.go b/pkg/runtime/optimize/show_variables.go new file mode 100644 index 00000000..e60170ee --- /dev/null +++ b/pkg/runtime/optimize/show_variables.go @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeShowVariables, optimizeShowVariables) +} + +func optimizeShowVariables(_ context.Context, o *optimizer) (proto.Plan, error) { + ret := plan.NewShowVariablesPlan(o.stmt.(*ast.ShowVariables)) + ret.BindArgs(o.args) + return ret, nil +} diff --git a/pkg/runtime/optimize/truncate.go b/pkg/runtime/optimize/truncate.go new file mode 100644 index 00000000..dba021ff --- /dev/null +++ b/pkg/runtime/optimize/truncate.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeTruncate, optimizeTruncate) +} + +func optimizeTruncate(_ context.Context, o *optimizer) (proto.Plan, error) { + stmt := o.stmt.(*ast.TruncateStatement) + shards, err := o.computeShards(stmt.Table, nil, o.args) + if err != nil { + return nil, errors.Wrap(err, "failed to optimize TRUNCATE statement") + } + + if shards == nil { + return plan.Transparent(stmt, o.args), nil + } + + ret := plan.NewTruncatePlan(stmt) + ret.BindArgs(o.args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/optimize/update.go b/pkg/runtime/optimize/update.go new file mode 100644 index 00000000..f1c08b6d --- /dev/null +++ b/pkg/runtime/optimize/update.go @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package optimize + +import ( + "context" +) + +import ( + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" +) + +func init() { + registerOptimizeHandler(ast.SQLTypeUpdate, optimizeUpdate) +} + +func optimizeUpdate(_ context.Context, o *optimizer) (proto.Plan, error) { + var ( + stmt = o.stmt.(*ast.UpdateStatement) + table = stmt.Table + vt *rule.VTable + ok bool + ) + + // non-sharding update + if vt, ok = o.rule.VTable(table.Suffix()); !ok { + ret := plan.NewUpdatePlan(stmt) + ret.BindArgs(o.args) + return ret, nil + } + + //check update sharding key + for _, element := range stmt.Updated { + if _, _, ok := vt.GetShardMetadata(element.Column.Suffix()); ok { + return nil, errors.New("do not support update sharding key") + } + } + + var ( + shards rule.DatabaseTables + fullScan = true + err error + ) + + // compute shards + if where := stmt.Where; where != nil { + sharder := (*Sharder)(o.rule) + if shards, fullScan, err = sharder.Shard(table, where, o.args...); err != nil { + return nil, errors.Wrap(err, "failed to update") + } + } + + // exit if full-scan is disabled + if fullScan && !vt.AllowFullScan() { + return nil, errDenyFullScan + } + + // must be empty shards (eg: update xxx set ... where 1 = 2 and uid = 1) + if shards.IsEmpty() { + return plan.AlwaysEmptyExecPlan{}, nil + } + + // compute all sharding tables + if shards.IsFullScan() { + // init shards + shards = rule.DatabaseTables{} + // compute all tables + topology := vt.Topology() + topology.Each(func(dbIdx, tbIdx int) bool { + if d, t, ok := topology.Render(dbIdx, tbIdx); ok { + shards[d] = append(shards[d], t) + } + return true + }) + } + + ret := plan.NewUpdatePlan(stmt) + ret.BindArgs(o.args) + ret.SetShards(shards) + + return ret, nil +} diff --git a/pkg/runtime/plan/drop_trigger.go b/pkg/runtime/plan/drop_trigger.go index aa1313b0..7f73a137 100644 --- a/pkg/runtime/plan/drop_trigger.go +++ b/pkg/runtime/plan/drop_trigger.go @@ -23,13 +23,13 @@ import ( ) import ( - "github.com/arana-db/arana/pkg/resultx" "github.com/pkg/errors" ) import ( "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" + "github.com/arana-db/arana/pkg/resultx" "github.com/arana-db/arana/pkg/runtime/ast" ) diff --git a/pkg/runtime/plan/transparent.go b/pkg/runtime/plan/transparent.go index 2cfa82b7..82ec180e 100644 --- a/pkg/runtime/plan/transparent.go +++ b/pkg/runtime/plan/transparent.go @@ -45,8 +45,8 @@ type TransparentPlan struct { func Transparent(stmt rast.Statement, args []interface{}) *TransparentPlan { var typ proto.PlanType switch stmt.Mode() { - case rast.Sinsert, rast.Sdelete, rast.Sreplace, rast.Supdate, rast.Struncate, rast.SdropTable, - rast.SalterTable, rast.DropIndex, rast.CreateIndex: + case rast.SQLTypeInsert, rast.SQLTypeDelete, rast.SQLTypeReplace, rast.SQLTypeUpdate, rast.SQLTypeTruncate, rast.SQLTypeDropTable, + rast.SQLTypeAlterTable, rast.SQLTypeDropIndex, rast.SQLTypeCreateIndex: typ = proto.PlanTypeExec default: typ = proto.PlanTypeQuery diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 6f95968f..94384ab4 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -46,9 +46,11 @@ import ( "github.com/arana-db/arana/pkg/metrics" "github.com/arana-db/arana/pkg/mysql" "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/proto/schema_manager" "github.com/arana-db/arana/pkg/resultx" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/namespace" + "github.com/arana-db/arana/pkg/runtime/optimize" "github.com/arana-db/arana/pkg/util/log" "github.com/arana-db/arana/pkg/util/rand2" "github.com/arana-db/arana/third_party/pools" @@ -58,8 +60,11 @@ var ( _ Runtime = (*defaultRuntime)(nil) _ proto.VConn = (*defaultRuntime)(nil) _ proto.VConn = (*compositeTx)(nil) +) - Tracer = otel.Tracer("Runtime") +var ( + Tracer = otel.Tracer("Runtime") + _defaultSchemaLoader proto.SchemaLoader = schema_manager.NewSimpleSchemaLoader() ) var ( @@ -218,10 +223,15 @@ func (tx *compositeTx) Execute(ctx *proto.Context) (res proto.Result, warn uint1 c = ctx.Context ) - c = rcontext.WithRule(c, ru) c = rcontext.WithSQL(c, ctx.GetQuery()) - if plan, err = tx.rt.ns.Optimizer().Optimize(c, tx, ctx.Stmt.StmtNode, args...); err != nil { + var opt proto.Optimizer + if opt, err = optimize.NewOptimizer(tx, _defaultSchemaLoader, ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil { + err = errors.WithStack(err) + return + } + + if plan, err = opt.Optimize(ctx); err != nil { err = errors.WithStack(err) return } @@ -627,14 +637,20 @@ func (pi *defaultRuntime) Execute(ctx *proto.Context) (res proto.Result, warn ui c = ctx.Context ) - c = rcontext.WithRule(c, ru) c = rcontext.WithSQL(c, ctx.GetQuery()) c = rcontext.WithSchema(c, ctx.Schema) c = rcontext.WithDBGroup(c, pi.ns.DBGroups()[0]) c = rcontext.WithTenant(c, ctx.Tenant) start := time.Now() - if plan, err = pi.ns.Optimizer().Optimize(c, pi, ctx.Stmt.StmtNode, args...); err != nil { + + var opt proto.Optimizer + if opt, err = optimize.NewOptimizer(pi, _defaultSchemaLoader, ru, ctx.Stmt.Hints, ctx.Stmt.StmtNode, args); err != nil { + err = errors.WithStack(err) + return + } + + if plan, err = opt.Optimize(c); err != nil { err = errors.WithStack(err) return } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 804fbcf9..2686b560 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -30,20 +30,18 @@ import ( import ( "github.com/arana-db/arana/pkg/runtime/namespace" - "github.com/arana-db/arana/testdata" ) func TestLoad(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - opt := testdata.NewMockOptimizer(ctrl) const schemaName = "FakeSchema" rt, err := Load(schemaName) assert.Error(t, err) assert.Nil(t, rt) - _ = namespace.Register(namespace.New(schemaName, opt)) + _ = namespace.Register(namespace.New(schemaName)) defer func() { _ = namespace.Unregister(schemaName) }() diff --git a/test/suite.go b/test/suite.go index adce6217..1a6fa44f 100644 --- a/test/suite.go +++ b/test/suite.go @@ -41,6 +41,12 @@ import ( "github.com/arana-db/arana/testdata" ) +const ( + timeout = "1s" + readTimeout = "3s" + writeTimeout = "5s" +) + type Option func(*MySuite) func WithDevMode() Option { @@ -113,7 +119,15 @@ func (ms *MySuite) DB() *sql.DB { } var ( - dsn = fmt.Sprintf("arana:123456@tcp(127.0.0.1:%d)/employees?timeout=1s&readTimeout=1s&writeTimeout=1s&parseTime=true&loc=Local&charset=utf8mb4,utf8", ms.port) + dsn = fmt.Sprintf( + "arana:123456@tcp(127.0.0.1:%d)/employees?"+ + "timeout=%s&"+ + "readTimeout=%s&"+ + "writeTimeout=%s&"+ + "parseTime=true&"+ + "loc=Local&"+ + "charset=utf8mb4,utf8", + ms.port, timeout, readTimeout, writeTimeout) err error ) diff --git a/testdata/mock_runtime.go b/testdata/mock_runtime.go index 4c8a8422..f37f0016 100644 --- a/testdata/mock_runtime.go +++ b/testdata/mock_runtime.go @@ -1,20 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - // Code generated by MockGen. DO NOT EDIT. // Source: github.com/arana-db/arana/pkg/proto (interfaces: VConn,Plan,Optimizer,DB,SchemaLoader) @@ -28,8 +11,6 @@ import ( ) import ( - ast "github.com/arana-db/parser/ast" - gomock "github.com/golang/mock/gomock" ) @@ -176,23 +157,18 @@ func (m *MockOptimizer) EXPECT() *MockOptimizerMockRecorder { } // Optimize mocks base method. -func (m *MockOptimizer) Optimize(arg0 context.Context, arg1 proto.VConn, arg2 ast.StmtNode, arg3 ...interface{}) (proto.Plan, error) { +func (m *MockOptimizer) Optimize(arg0 context.Context) (proto.Plan, error) { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1, arg2} - for _, a := range arg3 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Optimize", varargs...) + ret := m.ctrl.Call(m, "Optimize", arg0) ret0, _ := ret[0].(proto.Plan) ret1, _ := ret[1].(error) return ret0, ret1 } // Optimize indicates an expected call of Optimize. -func (mr *MockOptimizerMockRecorder) Optimize(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { +func (mr *MockOptimizerMockRecorder) Optimize(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Optimize", reflect.TypeOf((*MockOptimizer)(nil).Optimize), arg0) } // MockDB is a mock of DB interface.