Skip to content

Commit

Permalink
feat: implement insert and improve select (arana-db#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii authored Mar 28, 2022
1 parent ce46154 commit 2b6582d
Show file tree
Hide file tree
Showing 23 changed files with 793 additions and 187 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/arana-db/arana
go 1.16

require (
github.com/arana-db/parser v0.1.3
github.com/arana-db/parser v0.2.1
github.com/bwmarrin/snowflake v0.3.0
github.com/cespare/xxhash/v2 v2.1.2
github.com/dop251/goja v0.0.0-20220102113305-2298ace6d09d
Expand Down
68 changes: 28 additions & 40 deletions go.sum

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pkg/executor/redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ func (executor *RedirectExecutor) ExecutorComQuery(ctx *proto.Context) (proto.Re
} else {
res, warn, err = rt.Execute(ctx)
}
case *ast.InsertStmt:
if tx, ok := executor.getTx(ctx); ok {
res, warn, err = tx.Execute(ctx)
} else {
res, warn, err = rt.Execute(ctx)
}
default:
// TODO: mark direct flag temporarily, remove when write-mode is supported for runtime
ctx.Context = rcontext.WithDirect(ctx.Context)
Expand Down Expand Up @@ -213,7 +219,7 @@ func (executor *RedirectExecutor) ExecutorComStmtExecute(ctx *proto.Context) (pr
}

switch ctx.Stmt.StmtNode.(type) {
case *ast.SelectStmt:
case *ast.SelectStmt, *ast.InsertStmt:
default:
ctx.Context = rcontext.WithDirect(ctx.Context)
}
Expand Down
22 changes: 14 additions & 8 deletions pkg/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,15 +563,19 @@ func (l *Listener) ExecuteCommand(c *Conn, ctx *proto.Context) error {
}
return nil
}
rlt := result.(*Result)
if len(rlt.Fields) == 0 {
if len(result.GetFields()) == 0 {
// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
// We should not send any more packets after this, but make sure
// to extract the affected rows and last insert id from the result
// struct here since clients expect it.
return c.writeOKPacket(rlt.AffectedRows, rlt.InsertId, c.StatusFlags, warn)

var (
affected, _ = result.RowsAffected()
insertId, _ = result.LastInsertId()
)
return c.writeOKPacket(affected, insertId, c.StatusFlags, warn)
}
err = c.writeFields(l.capabilities, result)
if err != nil {
Expand Down Expand Up @@ -1204,14 +1208,17 @@ func (c *Conn) writeColumnDefinition(field *Field) error {
// writeFields writes the fields of a Result. It should be called only
// if there are valid Columns in the result.
func (c *Conn) writeFields(capabilities uint32, result proto.Result) error {
var (
fields = result.GetFields()
)

// Send the number of fields first.
rlt := result.(*Result)
if err := c.sendColumnCount(uint64(len(rlt.Fields))); err != nil {
if err := c.sendColumnCount(uint64(len(fields))); err != nil {
return err
}

// Now send each Field.
for _, field := range rlt.Fields {
for _, field := range fields {
fld := field.(*Field)
if err := c.writeColumnDefinition(fld); err != nil {
return err
Expand Down Expand Up @@ -1260,8 +1267,7 @@ func (c *Conn) writeRow(row []*proto.Value) error {

// writeRows sends the rows of a Result.
func (c *Conn) writeRows(result proto.Result) error {
rlt := result.(*Result)
for _, row := range rlt.Rows {
for _, row := range result.GetRows() {
r := row.(*Row)
textRow := TextRow{*r}
values, err := textRow.Decode()
Expand Down
32 changes: 26 additions & 6 deletions pkg/proto/rule/database_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,25 @@ func (dt DatabaseTables) IsConfused() bool {

// Smallest returns the smallest pair of database and table.
func (dt DatabaseTables) Smallest() (db, tbl string) {
for k := range dt {
if db == "" || strings.Compare(k, db) == -1 {
db = k
for k, v := range dt {
for _, it := range v {
if tbl == "" || strings.Compare(it, tbl) == -1 {
tbl = it
db = k
}
}
}
for _, it := range dt[db] {
if tbl == "" || strings.Compare(it, tbl) == -1 {
tbl = it
return
}

// Largest returns the largest pair of database and table.
func (dt DatabaseTables) Largest() (db, tbl string) {
for k, v := range dt {
for _, it := range v {
if tbl == "" || strings.Compare(it, tbl) == 1 {
tbl = it
db = k
}
}
}
return
Expand Down Expand Up @@ -274,6 +285,15 @@ func (dt DatabaseTables) IsEmpty() bool {
return dt != nil && len(dt) == 0
}

// Len returns amount of tables.
func (dt DatabaseTables) Len() int {
var n int
for _, v := range dt {
n += len(v)
}
return n
}

func (dt DatabaseTables) String() string {
if dt.IsFullScan() {
return `["*"]`
Expand Down
38 changes: 38 additions & 0 deletions pkg/proto/rule/database_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,41 @@ func TestDatabaseTables_IsConfused(t *testing.T) {
})
}
}

func TestDatabaseTables_Largest(t *testing.T) {
type tt struct {
input string
expectDb, expectTbl string
}

for _, it := range []tt{
{"db0:tb0,tb1;db1:tb2,tb3;db2:tb4,tb5", "db2", "tb5"},
{"db2:tb0,tb1;db1:tb2,tb3;db0:tb4,tb5", "db0", "tb5"},
} {
t.Run(it.input, func(t *testing.T) {
dt := parseDatabaseTablesFromString(it.input)
db, tbl := dt.Largest()
assert.Equal(t, it.expectTbl, tbl)
assert.Equal(t, it.expectDb, db)
})
}
}

func TestDatabaseTables_Smallest(t *testing.T) {
type tt struct {
input string
expectDb, expectTbl string
}

for _, it := range []tt{
{"db0:tb0,tb1;db1:tb2,tb3;db2:tb4,tb5", "db0", "tb0"},
{"db2:tb0,tb1;db1:tb2,tb3;db0:tb4,tb5", "db2", "tb0"},
} {
t.Run(it.input, func(t *testing.T) {
dt := parseDatabaseTablesFromString(it.input)
db, tbl := dt.Smallest()
assert.Equal(t, it.expectTbl, tbl)
assert.Equal(t, it.expectDb, db)
})
}
}
135 changes: 105 additions & 30 deletions pkg/proto/rule/range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,55 +48,125 @@ func TestStepper_Date_After(t *testing.T) {
}

func TestStepper_After(t *testing.T) {
st := Stepper{
N: 2,
U: Unum,
type tt struct {
st Stepper
offset interface{}
expect interface{}
}

date := parseDate("2022-01-01 00:00:00")

for _, it := range []tt{
{Stepper{N: 2, U: Unum}, 2, 4},
{Stepper{N: 2, U: Unum}, int32(2), int32(4)},
{Stepper{N: 2, U: Unum}, int64(2), int64(4)},
{Stepper{N: 1, U: Uhour}, date, parseDate("2022-01-01 01:00:00")},
{Stepper{N: 1, U: Uday}, date, parseDate("2022-01-02 00:00:00")},
{Stepper{N: 1, U: Uweek}, date, parseDate("2022-01-08 00:00:00")},
} {
t.Run(it.st.String(), func(t *testing.T) {
val, err := it.st.After(it.offset)
assert.NoError(t, err)
assert.Equal(t, it.expect, val)
})
}
val, err := st.After(2)
assert.NoError(t, err)
assert.Equal(t, 4, val)
}

func TestStepper_Before(t *testing.T) {
st := Stepper{
N: 1,
U: Unum,
type tt struct {
st Stepper
offset interface{}
expect interface{}
}

date := parseDate("2022-01-01 00:00:00")

for _, it := range []tt{
{Stepper{N: 2, U: Unum}, 4, 2},
{Stepper{N: 2, U: Unum}, int32(4), int32(2)},
{Stepper{N: 2, U: Unum}, int64(4), int64(2)},
{Stepper{N: 1, U: Uhour}, date, parseDate("2021-12-31 23:00:00")},
{Stepper{N: 1, U: Uday}, date, parseDate("2021-12-31 00:00:00")},
{Stepper{N: 1, U: Uweek}, date, parseDate("2021-12-25 00:00:00")},
} {
t.Run(it.st.String(), func(t *testing.T) {
val, err := it.st.Before(it.offset)
assert.NoError(t, err)
assert.Equal(t, it.expect, val)
})
}
val, err := st.Before(2)
assert.NoError(t, err)
assert.Equal(t, 1, val)
}

func TestStepper_Ascend(t *testing.T) {
st := Stepper{
N: 1,
U: Unum,
t.Run("WithNil", func(t *testing.T) {
st := Stepper{N: 1, U: Unum}
_, err := st.Ascend(nil, 1)
assert.Error(t, err)
})

type tt struct {
st Stepper
offset interface{}
n int
expect []interface{}
}

rng, err := st.Ascend(100, 3)
assert.NoError(t, err)
date := parseDate("2022-01-01 00:00:00")

var vals []int
for rng.HasNext() {
vals = append(vals, rng.Next().(int))
for _, it := range []tt{
{Stepper{N: 1, U: Unum}, 100, 3, []interface{}{100, 101, 102}},
{Stepper{N: 1, U: Unum}, int32(100), 3, []interface{}{int32(100), int32(101), int32(102)}},
{Stepper{N: 1, U: Unum}, int64(100), 3, []interface{}{int64(100), int64(101), int64(102)}},
{Stepper{N: 1, U: Uhour}, date, 3, []interface{}{parseDate("2022-01-01 00:00:00"), parseDate("2022-01-01 01:00:00"), parseDate("2022-01-01 02:00:00")}},
{Stepper{N: 1, U: Uday}, date, 3, []interface{}{parseDate("2022-01-01 00:00:00"), parseDate("2022-01-02 00:00:00"), parseDate("2022-01-03 00:00:00")}},
} {
t.Run(it.st.String(), func(t *testing.T) {
rng, err := it.st.Ascend(it.offset, it.n)
assert.NoError(t, err)

var vals []interface{}
for rng.HasNext() {
vals = append(vals, rng.Next())
}
assert.Equal(t, it.expect, vals)
})
}
assert.Equal(t, []int{100, 101, 102}, vals)
}

func TestStepper_Descend(t *testing.T) {
st := Stepper{
N: 1,
U: Unum,
t.Run("WithNil", func(t *testing.T) {
st := Stepper{N: 1, U: Unum}
_, err := st.Descend(nil, 1)
assert.Error(t, err)
})

type tt struct {
st Stepper
offset interface{}
n int
expect []interface{}
}

rng, err := st.Descend(100, 3)
assert.NoError(t, err)
date := parseDate("2022-01-01 00:00:00")

var vals []int
for rng.HasNext() {
vals = append(vals, rng.Next().(int))
for _, it := range []tt{
{Stepper{N: 1, U: Unum}, 100, 3, []interface{}{100, 99, 98}},
{Stepper{N: 1, U: Unum}, int32(100), 3, []interface{}{int32(100), int32(99), int32(98)}},
{Stepper{N: 1, U: Unum}, int64(100), 3, []interface{}{int64(100), int64(99), int64(98)}},
{Stepper{N: 1, U: Uhour}, date, 3, []interface{}{parseDate("2022-01-01 00:00:00"), parseDate("2021-12-31 23:00:00"), parseDate("2021-12-31 22:00:00")}},
{Stepper{N: 1, U: Uday}, date, 3, []interface{}{parseDate("2022-01-01 00:00:00"), parseDate("2021-12-31 00:00:00"), parseDate("2021-12-30 00:00:00")}},
} {
t.Run(it.st.String(), func(t *testing.T) {
rng, err := it.st.Descend(it.offset, it.n)
assert.NoError(t, err)

var vals []interface{}
for rng.HasNext() {
vals = append(vals, rng.Next())
}
assert.Equal(t, it.expect, vals)
})
}
assert.Equal(t, []int{100, 99, 98}, vals)
}

func TestStepUnit_String(t *testing.T) {
Expand All @@ -119,3 +189,8 @@ func TestStepUnit_String(t *testing.T) {
}
assert.Equal(t, "UNKNOWN", StepUnit(0x7F).String())
}

func parseDate(s string) time.Time {
ret, _ := time.Parse("2006-01-02 15:04:05", s)
return ret
}
12 changes: 12 additions & 0 deletions pkg/proto/rule/topology.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,15 @@ func (to *Topology) Render(dbIdx, tblIdx int) (string, string, bool) {
}
return to.dbRender(dbIdx), to.tbRender(tblIdx), true
}

// Each enumerates items in current Topology.
func (to *Topology) Each(onEach func(dbIdx, tbIdx int) (ok bool)) bool {
for d, v := range to.idx {
for _, t := range v {
if !onEach(d, t) {
return false
}
}
}
return true
}
8 changes: 8 additions & 0 deletions pkg/proto/rule/topology_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ func TestRenderForRenderNotNil(t *testing.T) {
assert.True(t, ok)
}

func TestTopology_Each(t *testing.T) {
topology := createTopology()
topology.Each(func(dbIdx, tbIdx int) bool {
t.Logf("on each: %d,%d\n", dbIdx, tbIdx)
return true
})
}

func createTopology() *Topology {
result := &Topology{
dbRender: func(i int) string {
Expand Down
Loading

0 comments on commit 2b6582d

Please sign in to comment.