Skip to content

Commit

Permalink
Merge pull request hulklab#5 from Pearlzju/master
Browse files Browse the repository at this point in the history
support custom validate
  • Loading branch information
kowloonzh authored Sep 27, 2019
2 parents e388446 + e1b40ad commit 3d1327b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 126 deletions.
137 changes: 134 additions & 3 deletions ctx.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package yago

import (
"fmt"
"log"
"github.com/gin-gonic/gin"
"github.com/hulklab/yago/libs/validator"
"mime/multipart"
"reflect"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -33,7 +36,7 @@ func (c *Ctx) Validate() error {
if router, ok := HttpRouterMap[url]; ok {
rules := router.h.Rules()
labels := router.h.Labels()
check, err := validator.ValidateHttp(c.Context, action, labels, rules)
check, err := ValidateHttp(c, action, labels, rules)
if !check {
return err
}
Expand Down Expand Up @@ -151,7 +154,6 @@ func (c *Ctx) RequestFileContent(key string) ([]byte, error) {
}

func (c *Ctx) SetData(data interface{}) {

c.resp = &ResponseBody{
ErrNo: OK.Code(),
ErrMsg: OK.Error(),
Expand Down Expand Up @@ -181,7 +183,6 @@ func (c *Ctx) SetError(err Err, msgEx ...string) {
}

func (c *Ctx) SetDataOrErr(data interface{}, err Err) {

if err.HasErr() {
c.SetError(err)
return
Expand All @@ -197,3 +198,133 @@ func (c *Ctx) GetResponse() (*ResponseBody, bool) {

return c.resp, true
}

func ValidateHttp(c *Ctx, action string, labels validator.Label, rules []validator.Rule) (bool, error) {
type CustomValidatorFunc = func(c *Ctx, p string) (valid bool, err error)

for _, rule := range rules {
actionMatch := false
if len(rule.On) == 0 {
actionMatch = true
} else {
for _, a := range rule.On {
if a == action {
actionMatch = true
break
}
}
}

if actionMatch {
switch method := rule.Method.(type) {
case int:
return validateByRule(c, labels, rule, method)
case CustomValidatorFunc:
for _, p := range rule.Params {
_, exist := c.Get(p)
if !exist {
continue
}
if valid, err := method(c, p); !valid {
return false, err
}
}
default:
log.Fatalf("not support method: %s", reflect.TypeOf(rule.Method))
}
}
}
return true, nil
}

func validateByRule(c *Ctx, labels validator.Label, rule validator.Rule, method int) (bool, error) {
switch method {
case validator.Required:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
return false, fmt.Errorf("%s 不存在", labels.Get(p))
}
if valid, err := (validator.RequiredValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.String:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (validator.StringValidator{Min: int(rule.Min), Max: int(rule.Max)}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.Int:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
pvInt, err := strconv.Atoi(pv.(string))
if err != nil {
return false, fmt.Errorf("%s 不是个整数", labels.Get(p))
}
if valid, err := (validator.IntValidator{Min: int(rule.Min), Max: int(rule.Max)}).Check(pvInt); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.Float:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}

pvFloat, err := strconv.ParseFloat(pv.(string), 64)
if err != nil {
return false, fmt.Errorf("%s 不是个浮点数", labels.Get(p))
}
if valid, err := (validator.FloatValidator{Min: rule.Min, Max: rule.Max}).Check(pvFloat); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.JSON:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (validator.JSONValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.IP:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (validator.IPValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case validator.Match:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (validator.MatchValidator{Pattern: rule.Pattern}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
}
return true, nil
}

func getErr(label string, err error, message string) error {
if message == "" {
return fmt.Errorf("%s %s", label, err)
}
return fmt.Errorf("%s %s", label, message)
}
129 changes: 7 additions & 122 deletions libs/validator/check.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
package validator

import (
"fmt"
"github.com/gin-gonic/gin"
"strconv"
)

type Rule struct {
Params []string
Method int
On []string
Min float64
Max float64
Pattern string
Message string
Params []string
Method interface{}
On []string
Min float64
Max float64
Pattern string
Message string
}

type Label map[string]string
Expand All @@ -24,112 +18,3 @@ func (l *Label) Get(key string) string {
}
return key
}

func ValidateHttp(c *gin.Context, action string, labels Label, rules []Rule) (bool, error) {
for _, rule := range rules {
actionMatch := false
if len(rule.On) == 0 {
actionMatch = true
} else {
for _, a := range rule.On {
if a == action {
actionMatch = true
break
}
}
}

if actionMatch {
switch rule.Method {
case Required:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
return false, fmt.Errorf("%s 不存在", labels.Get(p))
}
if valid, err := (RequiredValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case String:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (StringValidator{Min: int(rule.Min), Max: int(rule.Max)}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case Int:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
pvInt, err := strconv.Atoi(pv.(string))
if err != nil {
return false, fmt.Errorf("%s 不是个整数", labels.Get(p))
}
if valid, err := (IntValidator{Min: int(rule.Min), Max: int(rule.Max)}).Check(pvInt); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case Float:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}

pvFloat, err := strconv.ParseFloat(pv.(string), 64)
if err != nil {
return false, fmt.Errorf("%s 不是个浮点数", labels.Get(p))
}
if valid, err := (FloatValidator{Min: rule.Min, Max: rule.Max}).Check(pvFloat); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case JSON:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (JSONValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case IP:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (IPValidator{}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
case Match:
for _, p := range rule.Params {
pv, exist := c.Get(p)
if !exist {
continue
}
if valid, err := (MatchValidator{Pattern: rule.Pattern}).Check(pv); !valid {
return false, getErr(labels.Get(p), err, rule.Message)
}
}
}
}
}
return true, nil
}

func getErr(label string, err error, message string) error {
if message == "" {
return fmt.Errorf("%s %s", label, err)
}
return fmt.Errorf("%s %s", label, message)

}
3 changes: 2 additions & 1 deletion libs/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
)

const (
Required = iota
_ = iota
Required
Int
Float
String
Expand Down

0 comments on commit 3d1327b

Please sign in to comment.