Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into shenli/group-concat
Browse files Browse the repository at this point in the history
  • Loading branch information
shenli committed Jan 22, 2016
2 parents a5e79c6 + a59176b commit aa431ee
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 23 deletions.
3 changes: 2 additions & 1 deletion ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func (n *AggregateFuncExpr) Update() error {
return n.updateMaxMin(true)
case AggFuncMin:
return n.updateMaxMin(false)
case AggFuncSum:
case AggFuncSum, AggFuncAvg:
return n.updateSum()
case AggFuncGroupConcat:
return n.updateGroupConcat()
Expand Down Expand Up @@ -534,6 +534,7 @@ func (n *AggregateFuncExpr) updateSum() error {
if err != nil {
return errors.Trace(err)
}
ctx.Count++
return nil
}

Expand Down
13 changes: 13 additions & 0 deletions optimizer/evaluator/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,8 @@ func parseDayInterval(value interface{}) (int64, error) {
func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
name := strings.ToLower(v.F)
switch name {
case ast.AggFuncAvg:
e.evalAggAvg(v)
case ast.AggFuncCount:
e.evalAggCount(v)
case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat:
Expand All @@ -973,3 +975,14 @@ func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetValue(ctx.Value)
}

func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
switch x := ctx.Value.(type) {
case float64:
ctx.Value = x / float64(ctx.Count)
case mysql.Decimal:
ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0))
}
v.SetValue(ctx.Value)
}
24 changes: 24 additions & 0 deletions optimizer/evaluator/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1054,3 +1054,27 @@ func (s *testEvaluatorSuite) TestColumnNameExpr(c *C) {
c.Assert(err, IsNil)
c.Assert(result, Equals, 2)
}

func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) {
ctx := mock.NewContext()
avg := &ast.AggregateFuncExpr{
F: ast.AggFuncAvg,
}
avg.CurrentGroup = "emptyGroup"
result, err := Eval(ctx, avg)
c.Assert(err, IsNil)
// Empty group should return nil.
c.Assert(result, IsNil)

avg.Args = []ast.ExprNode{ast.NewValueExpr(2)}
avg.Update()
avg.Args = []ast.ExprNode{ast.NewValueExpr(4)}
avg.Update()

result, err = Eval(ctx, avg)
c.Assert(err, IsNil)
expect, _ := mysql.ConvertToDecimal(3)
v, ok := result.(mysql.Decimal)
c.Assert(ok, IsTrue)
c.Assert(v.Equals(expect), IsTrue)
}
36 changes: 15 additions & 21 deletions optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,10 @@ type supportChecker struct {
}

func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
switch ti := in.(type) {
switch x := in.(type) {
case *ast.SubqueryExpr:
c.unsupported = true
case *ast.AggregateFuncExpr:
fn := strings.ToLower(ti.F)
switch fn {
case ast.AggFuncCount, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat:
default:
c.unsupported = true
}
case *ast.Join:
x := in.(*ast.Join)
if x.Right != nil {
c.unsupported = true
} else {
Expand All @@ -93,7 +85,6 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
}
}
case *ast.SelectStmt:
x := in.(*ast.SelectStmt)
if x.Distinct {
c.unsupported = true
}
Expand Down Expand Up @@ -123,25 +114,28 @@ func IsSupported(node ast.Node) bool {

// Optimizer error codes.
const (
CodeOneColumn terror.ErrCode = 1
CodeSameColumns = 2
CodeMultiWildCard = 3
CodeUnsupported = 4
CodeOneColumn terror.ErrCode = 1
CodeSameColumns terror.ErrCode = 2
CodeMultiWildCard terror.ErrCode = 3
CodeUnsupported terror.ErrCode = 4
CodeInvalidGroupFuncUse terror.ErrCode = 5
)

// Optimizer base errors.
var (
ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)")
ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns")
ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once")
ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported")
ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)")
ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns")
ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once")
ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported")
ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(CodeInvalidGroupFuncUse, "Invalid use of group function")
)

func init() {
mySQLErrCodes := map[terror.ErrCode]uint16{
CodeOneColumn: mysql.ErrOperandColumns,
CodeSameColumns: mysql.ErrOperandColumns,
CodeMultiWildCard: mysql.ErrParse,
CodeOneColumn: mysql.ErrOperandColumns,
CodeSameColumns: mysql.ErrOperandColumns,
CodeMultiWildCard: mysql.ErrParse,
CodeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse,
}
terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mySQLErrCodes
}
2 changes: 1 addition & 1 deletion optimizer/typeinferer.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
x.SetType(ft)
case ast.AggFuncMax, ast.AggFuncMin:
x.SetType(x.Args[0].GetType())
case ast.AggFuncSum:
case ast.AggFuncSum, ast.AggFuncAvg:
ft := types.NewFieldType(mysql.TypeNewDecimal)
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
Expand Down
12 changes: 12 additions & 0 deletions optimizer/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,26 @@ type validator struct {
err error
wildCardCount int
inPrepare bool
inAggregate bool
}

func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch in.(type) {
case *ast.AggregateFuncExpr:
if v.inAggregate {
// Aggregate function can not contain aggregate function.
v.err = ErrInvalidGroupFuncUse
return in, true
}
v.inAggregate = true
}
return in, false
}

func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
switch x := in.(type) {
case *ast.AggregateFuncExpr:
v.inAggregate = false
case *ast.BetweenExpr:
v.checkAllOneColumn(x.Expr, x.Left, x.Right)
case *ast.BinaryOperationExpr:
Expand Down
12 changes: 12 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,18 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool {
name := strs[0]
host := strs[1]
pwd, err := s.getPassword(name, host)
if err != nil {
if terror.ExecResultIsEmpty.Equal(err) {
log.Errorf("User [%s] not exist %v", name, err)
} else {
log.Errorf("Get User [%s] password from SystemDB error %v", name, err)
}
return false
}
if len(pwd) != 0 && len(pwd) != 40 {
log.Errorf("User [%s] password from SystemDB not like a sha1sum", name)
return false
}
hpwd, err := util.DecodePassword(pwd)
if err != nil {
log.Errorf("Decode password string error %v", err)
Expand Down
7 changes: 7 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,13 @@ func (s *testSessionSuite) TestSession(c *C) {
c.Assert(err, IsNil)
}

func (s *testSessionSuite) TestSessionAuth(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
defer se.Close()
c.Assert(se.Auth("Any not exist username with zero password! @anyhost", []byte(""), []byte("")), IsFalse)
}

func (s *testSessionSuite) TestErrorRollback(c *C) {
store := newStore(c, s.dbName)
s1 := newSession(c, store, s.dbName)
Expand Down

0 comments on commit aa431ee

Please sign in to comment.