From f655e1a1a877b762125e2d397b41cbec0c2a241f Mon Sep 17 00:00:00 2001 From: qqxhb <30866940+qqxhb@users.noreply.github.com> Date: Wed, 28 Jun 2023 14:54:39 +0800 Subject: [PATCH] feat: assign or attr map/struct (#894) * feat: assign or attr map/struct * feat: impl BeCond --- do.go | 4 +- field/assign_attr.go | 121 +++++++++++++++++++++++++++++++++++++++++++ field/expr.go | 6 ++- 3 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 field/assign_attr.go diff --git a/do.go b/do.go index 05857a34..06b89013 100644 --- a/do.go +++ b/do.go @@ -392,7 +392,9 @@ func (d *DO) Assign(attrs ...field.AssignExpr) Dao { func (d *DO) attrsValue(attrs []field.AssignExpr) []interface{} { values := make([]interface{}, 0, len(attrs)) for _, attr := range attrs { - if expr, ok := attr.AssignExpr().(clause.Eq); ok { + if expr, ok := attr.AssignExpr().(field.IValues); ok { + values = append(values, expr.Values()) + } else if expr, ok := attr.AssignExpr().(clause.Eq); ok { values = append(values, expr) } } diff --git a/field/assign_attr.go b/field/assign_attr.go new file mode 100644 index 00000000..eab5fb91 --- /dev/null +++ b/field/assign_attr.go @@ -0,0 +1,121 @@ +package field + +import ( + "reflect" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/utils/tests" +) + +var testDB, _ = gorm.Open(tests.DummyDialector{}, nil) + +type IValues interface { + Values() interface{} +} + +type attrs struct { + expr + value interface{} + db *gorm.DB + selectFields []IColumnName + omitFields []IColumnName +} + +func (att *attrs) AssignExpr() expression { + return att +} + +func (att *attrs) BeCond() interface{} { + return att.db.Statement.BuildCondition(att.Values()) +} + +func (att *attrs) Values() interface{} { + if att == nil || att.value == nil { + return nil + } + if len(att.selectFields) == 0 && len(att.omitFields) == 0 { + return att.value + } + values := make(map[string]interface{}) + if value, ok := att.value.(map[string]interface{}); ok { + values = value + } else if value, ok := att.value.(*map[string]interface{}); ok { + values = *value + } else { + reflectValue := reflect.Indirect(reflect.ValueOf(att.value)) + for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { + reflectValue = reflect.Indirect(reflectValue) + } + switch reflectValue.Kind() { + case reflect.Struct: + if err := att.db.Statement.Parse(att.value); err == nil { + ignoreZero := len(att.selectFields) == 0 + for _, f := range att.db.Statement.Schema.Fields { + if f.Readable { + if v, isZero := f.ValueOf(att.db.Statement.Context, reflectValue); !isZero || !ignoreZero { + values[f.DBName] = v + } + } + } + } + } + } + if len(att.selectFields) > 0 { + fm, all := toFieldMap(att.selectFields) + if all { + return values + } + tvs := make(map[string]interface{}, len(fm)) + for fn, vl := range values { + if fm[fn] { + tvs[fn] = vl + } + } + return tvs + } + fm, all := toFieldMap(att.omitFields) + if all { + return map[string]interface{}{} + } + for fn, _ := range fm { + delete(values, fn) + } + return values +} + +func toFieldMap(fields []IColumnName) (fieldsMap map[string]bool, all bool) { + fieldsMap = make(map[string]bool, len(fields)) + for _, f := range fields { + if strings.HasSuffix(string(f.ColumnName()), "*") { + all = true + return + } + fieldsMap[string(f.ColumnName())] = true + } + return +} + +func (att *attrs) Select(fields ...IColumnName) *attrs { + if att == nil || att.db == nil { + return att + } + att.selectFields = fields + return att +} + +func (att *attrs) Omit(fields ...IColumnName) *attrs { + if att == nil || att.db == nil { + return att + } + att.omitFields = fields + return att +} + +func Attrs(attr interface{}) *attrs { + res := &attrs{db: testDB.Debug()} + if attr != nil { + res.value = attr + } + return res +} diff --git a/field/expr.go b/field/expr.go index ff2fefe8..0c368f03 100644 --- a/field/expr.go +++ b/field/expr.go @@ -24,7 +24,7 @@ type Expr interface { Build(clause.Builder) As(alias string) Expr - ColumnName() sql + IColumnName BuildColumn(*gorm.Statement, ...BuildOpt) sql BuildWithArgs(*gorm.Statement) (query sql, args []interface{}) RawExpr() expression @@ -52,6 +52,10 @@ type OrderExpr interface { type expression interface{} +type IColumnName interface { + ColumnName() sql +} + type sql string func (e sql) String() string { return string(e) }