diff --git a/README.ZH_CN.md b/README.ZH_CN.md index d26e0c72..357f13e1 100644 --- a/README.ZH_CN.md +++ b/README.ZH_CN.md @@ -258,8 +258,8 @@ FieldNewTag // append new tag FieldNewTagWithNS // specify new tag with name strategy FieldTrimPrefix // trim column prefix FieldTrimSuffix // trim column suffix -FieldAddPrefix // add prefix to struct member's name -FieldAddSuffix // add suffix to struct member's name +FieldAddPrefix // add prefix to struct field's name +FieldAddSuffix // add suffix to struct field's name FieldRelate // specify relationship with other tables FieldRelateModel // specify relationship with exist models ``` diff --git a/README.md b/README.md index c4c8bc3c..ce1fdfcb 100644 --- a/README.md +++ b/README.md @@ -261,8 +261,8 @@ FieldNewTag // append new tag FieldNewTagWithNS // specify new tag with name strategy FieldTrimPrefix // trim column prefix FieldTrimSuffix // trim column suffix -FieldAddPrefix // add prefix to struct member's name -FieldAddSuffix // add suffix to struct member's name +FieldAddPrefix // add prefix to struct field's name +FieldAddSuffix // add suffix to struct field's name FieldRelate // specify relationship with other tables FieldRelateModel // specify relationship with exist models ``` diff --git a/config.go b/config.go index 6d326310..055fd6a9 100644 --- a/config.go +++ b/config.go @@ -37,11 +37,17 @@ type Config struct { Mode GenerateMode // generate mode - queryPkgName string // generated query code's package name - dbNameOpts []model.SchemaNameOpt + queryPkgName string // generated query code's package name + dbNameOpts []model.SchemaNameOpt + + // name strategy for syncing table from db + tableNameNS func(tableName string) (targetTableName string) + modelNameNS func(tableName string) (modelName string) + fileNameNS func(tableName string) (fielName string) + dataTypeMap map[string]func(detailType string) (dataType string) - fieldJSONTagNS func(columnName string) string - fieldNewTagNS func(columnName string) string + fieldJSONTagNS func(columnName string) (tagContent string) + fieldNewTagNS func(columnName string) (tagContent string) } // WithDbNameOpts set get database name function @@ -53,14 +59,32 @@ func (cfg *Config) WithDbNameOpts(opts ...model.SchemaNameOpt) { } } +// WithTableNameStrategy specify table name naming strategy, only work when syncing table from db +func (cfg *Config) WithTableNameStrategy(ns func(tableName string) (targetTableName string)) { + cfg.tableNameNS = ns +} + +// WithModelNameStrategy specify model struct name naming strategy, only work when syncing table from db +func (cfg *Config) WithModelNameStrategy(ns func(tableName string) (modelName string)) { + cfg.modelNameNS = ns +} + +// WithFileNameStrategy specify file name naming strategy, only work when syncing table from db +func (cfg *Config) WithFileNameStrategy(ns func(tableName string) (fielName string)) { + cfg.fileNameNS = ns +} + +// WithDataTypeMap specify data type mapping relationship, only work when syncing table from db func (cfg *Config) WithDataTypeMap(newMap map[string]func(detailType string) (dataType string)) { cfg.dataTypeMap = newMap } +// WithJSONTagNameStrategy specify json tag naming strategy func (cfg *Config) WithJSONTagNameStrategy(ns func(columnName string) (tagContent string)) { cfg.fieldJSONTagNS = ns } +// WithNewTagNameStrategy specify new tag naming strategy func (cfg *Config) WithNewTagNameStrategy(ns func(columnName string) (tagContent string)) { cfg.fieldNewTagNS = ns } diff --git a/field/association.go b/field/association.go index 86bf25a2..57f2b452 100644 --- a/field/association.go +++ b/field/association.go @@ -22,7 +22,7 @@ var ns = schema.NamingStrategy{} type RelationField interface { Name() string Path() string - Field(member ...string) Expr + Field(fields ...string) Expr On(conds ...Expr) RelationField Order(columns ...Expr) RelationField @@ -66,9 +66,9 @@ func (r Relation) RelationshipName() string { return ns.SchemaName(string(r.rela func (r Relation) ChildRelations() []Relation { return r.childRelations } -func (r Relation) Field(member ...string) Expr { - if len(member) > 0 { - return NewString("", r.fieldName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote) +func (r Relation) Field(fields ...string) Expr { + if len(fields) > 0 { + return NewString("", r.fieldName+"."+strings.Join(fields, ".")).appendBuildOpts(WithoutQuote) } return NewString("", r.fieldName).appendBuildOpts(WithoutQuote) } @@ -103,19 +103,18 @@ func (r *Relation) GetOrderCol() []Expr { return r.order } func (r *Relation) GetClauses() []clause.Expression { return r.clauses } func (r *Relation) GetPage() (offset, limit int) { return r.offset, r.limit } -func (r *Relation) StructMember() string { - var memberStr string +func (r *Relation) StructField() (fieldStr string) { for _, relation := range r.childRelations { - memberStr += relation.fieldName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n" + fieldStr += relation.fieldName + " struct {\nfield.RelationField\n" + relation.StructField() + "}\n" } - return memberStr + return fieldStr } -func (r *Relation) StructMemberInit() string { +func (r *Relation) StructFieldInit() string { initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.fieldPath, r.fieldType) for _, relation := range r.childRelations { - initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}" - initStr += "{\n" + relation.StructMemberInit() + "},\n" + initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructField()) + "}" + initStr += "{\n" + relation.StructFieldInit() + "},\n" } return initStr } diff --git a/field/external_test.go b/field/external_test.go index 13831e6d..8c373159 100644 --- a/field/external_test.go +++ b/field/external_test.go @@ -390,7 +390,7 @@ func BenchmarkExpr_Count(b *testing.B) { } } -func TestRelation_StructMember(t *testing.T) { +func TestRelation_StructField(t *testing.T) { var testdatas = []struct { relation *field.Relation expectedValue string @@ -411,13 +411,13 @@ func TestRelation_StructMember(t *testing.T) { } for _, testdata := range testdatas { - if result := testdata.relation.StructMember(); result != testdata.expectedValue { - t.Errorf("StructMember fail: except %q, got %q", testdata.expectedValue, result) + if result := testdata.relation.StructField(); result != testdata.expectedValue { + t.Errorf("StructField fail: except %q, got %q", testdata.expectedValue, result) } } } -func TestRelation_StructMemberInit(t *testing.T) { +func TestRelation_StructFieldInit(t *testing.T) { var testdatas = []struct { relation *field.Relation expectedValue string @@ -438,8 +438,8 @@ func TestRelation_StructMemberInit(t *testing.T) { } for _, testdata := range testdatas { - if result := testdata.relation.StructMemberInit(); result != testdata.expectedValue { - t.Errorf("StructMember fail: except %q, got %q", testdata.expectedValue, result) + if result := testdata.relation.StructFieldInit(); result != testdata.expectedValue { + t.Errorf("StructField fail: except %q, got %q", testdata.expectedValue, result) } } } diff --git a/field_options.go b/field_options.go index f164ab8d..7a2e28e2 100644 --- a/field_options.go +++ b/field_options.go @@ -16,9 +16,9 @@ var ns = schema.NamingStrategy{} var ( // FieldNew add new field (any type your want) - FieldNew = func(fieldName, fieldType, fieldTag string) model.CreateMemberOpt { - return func(*model.Member) *model.Member { - return &model.Member{ + FieldNew = func(fieldName, fieldType, fieldTag string) model.CreateFieldOpt { + return func(*model.Field) *model.Field { + return &model.Field{ Name: fieldName, Type: fieldType, OverwriteTag: fieldTag, @@ -26,8 +26,8 @@ var ( } } // FieldIgnore ignore some columns by name - FieldIgnore = func(columnNames ...string) model.FilterMemberOpt { - return func(m *model.Member) *model.Member { + FieldIgnore = func(columnNames ...string) model.FilterFieldOpt { + return func(m *model.Field) *model.Field { for _, name := range columnNames { if m.ColumnName == name { return nil @@ -37,12 +37,12 @@ var ( } } // FieldIgnoreReg ignore some columns by RegExp - FieldIgnoreReg = func(columnNameRegs ...string) model.FilterMemberOpt { + FieldIgnoreReg = func(columnNameRegs ...string) model.FilterFieldOpt { regs := make([]regexp.Regexp, len(columnNameRegs)) for i, reg := range columnNameRegs { regs[i] = *regexp.MustCompile(reg) } - return func(m *model.Member) *model.Member { + return func(m *model.Field) *model.Field { for _, reg := range regs { if reg.MatchString(m.ColumnName) { return nil @@ -52,8 +52,8 @@ var ( } } // FieldRename specify field name in generated struct - FieldRename = func(columnName string, newName string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldRename = func(columnName string, newName string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.Name = newName } @@ -61,8 +61,8 @@ var ( } } // FieldType specify field type in generated struct - FieldType = func(columnName string, newType string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldType = func(columnName string, newType string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.Type = newType } @@ -70,9 +70,9 @@ var ( } } // FieldIgnoreType ignore some columns by RegExp - FieldTypeReg = func(columnNameReg string, newType string) model.ModifyMemberOpt { + FieldTypeReg = func(columnNameReg string, newType string) model.ModifyFieldOpt { reg := regexp.MustCompile(columnNameReg) - return func(m *model.Member) *model.Member { + return func(m *model.Field) *model.Field { if reg.MatchString(m.ColumnName) { m.Type = newType } @@ -80,8 +80,8 @@ var ( } } // FieldTag specify GORM tag and JSON tag - FieldTag = func(columnName string, gormTag, jsonTag string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldTag = func(columnName string, gormTag, jsonTag string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.GORMTag, m.JSONTag = gormTag, jsonTag } @@ -89,8 +89,8 @@ var ( } } // FieldJSONTag specify JSON tag - FieldJSONTag = func(columnName string, jsonTag string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldJSONTag = func(columnName string, jsonTag string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.JSONTag = jsonTag } @@ -98,8 +98,8 @@ var ( } } // FieldJSONTagWithNS specify JSON tag with name strategy - FieldJSONTagWithNS = func(schemaName func(columnName string) (tagContent string)) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldJSONTagWithNS = func(schemaName func(columnName string) (tagContent string)) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if schemaName != nil { m.JSONTag = schemaName(m.ColumnName) } @@ -107,8 +107,8 @@ var ( } } // FieldGORMTag specify GORM tag - FieldGORMTag = func(columnName string, gormTag string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldGORMTag = func(columnName string, gormTag string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.GORMTag = gormTag } @@ -116,8 +116,8 @@ var ( } } // FieldNewTag add new tag - FieldNewTag = func(columnName string, newTag string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldNewTag = func(columnName string, newTag string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if m.ColumnName == columnName { m.NewTag += " " + newTag } @@ -125,8 +125,8 @@ var ( } } // FieldNewTagWithNS add new tag with name strategy - FieldNewTagWithNS = func(tagName string, schemaName func(columnName string) string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldNewTagWithNS = func(tagName string, schemaName func(columnName string) string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { if schemaName == nil { schemaName = func(name string) string { return name } } @@ -135,43 +135,43 @@ var ( } } // FieldTrimPrefix trim column name's prefix - FieldTrimPrefix = func(prefix string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldTrimPrefix = func(prefix string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { m.Name = strings.TrimPrefix(m.Name, prefix) return m } } // FieldTrimSuffix trim column name's suffix - FieldTrimSuffix = func(suffix string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldTrimSuffix = func(suffix string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { m.Name = strings.TrimSuffix(m.Name, suffix) return m } } // FieldAddPrefix add prefix to struct's memeber name - FieldAddPrefix = func(prefix string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldAddPrefix = func(prefix string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { m.Name = prefix + m.Name return m } } // FieldAddSuffix add suffix to struct's memeber name - FieldAddSuffix = func(suffix string) model.ModifyMemberOpt { - return func(m *model.Member) *model.Member { + FieldAddSuffix = func(suffix string) model.ModifyFieldOpt { + return func(m *model.Field) *model.Field { m.Name += suffix return m } } // FieldRelate relate to table in database - FieldRelate = func(relationship field.RelationshipType, fieldName string, table *check.BaseStruct, config *field.RelateConfig) model.CreateMemberOpt { + FieldRelate = func(relationship field.RelationshipType, fieldName string, table *check.BaseStruct, config *field.RelateConfig) model.CreateFieldOpt { if config == nil { config = &field.RelateConfig{} } if config.JSONTag == "" { config.JSONTag = ns.ColumnName("", fieldName) } - return func(*model.Member) *model.Member { - return &model.Member{ + return func(*model.Field) *model.Field { + return &model.Field{ Name: fieldName, Type: config.RelateFieldPrefix(relationship) + table.StructInfo.Type, JSONTag: config.JSONTag, @@ -186,7 +186,7 @@ var ( } } // FieldRelateModel relate to exist table model - FieldRelateModel = func(relationship field.RelationshipType, fieldName string, relModel interface{}, config *field.RelateConfig) model.CreateMemberOpt { + FieldRelateModel = func(relationship field.RelationshipType, fieldName string, relModel interface{}, config *field.RelateConfig) model.CreateFieldOpt { st := reflect.TypeOf(relModel) if st.Kind() == reflect.Ptr { st = st.Elem() @@ -200,8 +200,8 @@ var ( config.JSONTag = ns.ColumnName("", fieldName) } - return func(*model.Member) *model.Member { - return &model.Member{ + return func(*model.Field) *model.Field { + return &model.Field{ Name: fieldName, Type: config.RelateFieldPrefix(relationship) + fieldType, JSONTag: config.JSONTag, diff --git a/generator.go b/generator.go index 482a41e5..17586770 100644 --- a/generator.go +++ b/generator.go @@ -92,27 +92,32 @@ func (g *Generator) UseDB(db *gorm.DB) { */ // GenerateModel catch table info from db, return a BaseStruct -func (g *Generator) GenerateModel(tableName string, opts ...model.MemberOpt) *check.BaseStruct { +func (g *Generator) GenerateModel(tableName string, opts ...model.FieldOpt) *check.BaseStruct { return g.GenerateModelAs(tableName, g.db.Config.NamingStrategy.SchemaName(tableName), opts...) } // GenerateModel catch table info from db, return a BaseStruct -func (g *Generator) GenerateModelAs(tableName string, modelName string, fieldOpts ...model.MemberOpt) *check.BaseStruct { - tableName = g.tableName(tableName) - s, err := check.GenBaseStructs(g.db, model.DBConf{ +func (g *Generator) GenerateModelAs(tableName string, modelName string, fieldOpts ...model.FieldOpt) *check.BaseStruct { + s, err := check.GenBaseStructs(g.db, model.Conf{ ModelPkg: g.Config.ModelPkgPath, + TablePrefix: g.getTablePrefix(), TableName: tableName, ModelName: modelName, SchemaNameOpts: g.dbNameOpts, - MemberOpts: fieldOpts, - DataTypeMap: g.dataTypeMap, - GenerateModelConfig: model.GenerateModelConfig{ + TableNameNS: g.tableNameNS, + ModelNameNS: g.modelNameNS, + FileNameNS: g.fileNameNS, + FieldConf: model.FieldConf{ + DataTypeMap: g.dataTypeMap, + FieldNullable: g.FieldNullable, FieldWithIndexTag: g.FieldWithIndexTag, FieldWithTypeTag: g.FieldWithTypeTag, FieldJSONTagNS: g.fieldJSONTagNS, FieldNewTagNS: g.fieldNewTagNS, + + FieldOpts: fieldOpts, }, }) if err != nil { @@ -121,23 +126,19 @@ func (g *Generator) GenerateModelAs(tableName string, modelName string, fieldOpt } g.modelData[s.StructName] = s - g.successInfo(fmt.Sprintf("got %d columns from table <%s>", len(s.Members), s.TableName)) + g.successInfo(fmt.Sprintf("got %d columns from table <%s>", len(s.Fields), s.TableName)) return s } -func (g *Generator) tableName(table string) string { +func (g *Generator) getTablePrefix() string { if ns, ok := g.db.NamingStrategy.(schema.NamingStrategy); ok { - if strings.HasPrefix(table, ns.TablePrefix) { - return table - } else { - return ns.TablePrefix + table - } + return ns.TablePrefix } - return table + return "" } // GenerateAllTable generate all tables in db -func (g *Generator) GenerateAllTable(opts ...model.MemberOpt) (tableModels []interface{}) { +func (g *Generator) GenerateAllTable(opts ...model.FieldOpt) (tableModels []interface{}) { tableList, err := g.db.Migrator().GetTables() if err != nil { panic(fmt.Sprintf("get all tables fail: %s", err)) @@ -184,7 +185,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) { for _, interfaceStruct := range structs { if g.judgeMode(WithoutContext) { - interfaceStruct.ReviseMemberName() + interfaceStruct.ReviseFieldName() } data, err := g.pushBaseStruct(interfaceStruct) @@ -344,8 +345,8 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) { return err } - defer g.successInfo(fmt.Sprintf("generate query file: %s/%s.gen.go", g.OutPath, strings.ToLower(data.TableName))) - return g.output(fmt.Sprintf("%s/%s.gen.go", g.OutPath, strings.ToLower(data.TableName)), buf.Bytes()) + defer g.successInfo(fmt.Sprintf("generate query file: %s/%s.gen.go", g.OutPath, data.FileName)) + return g.output(fmt.Sprintf("%s/%s.gen.go", g.OutPath, data.FileName), buf.Bytes()) } // generateQueryUnitTestFile generate unit test file for query @@ -369,8 +370,8 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) { } } - defer g.successInfo(fmt.Sprintf("generate unit test file: %s/%s.gen_test.go", g.OutPath, strings.ToLower(data.TableName))) - return g.output(fmt.Sprintf("%s/%s.gen_test.go", g.OutPath, strings.ToLower(data.TableName)), buf.Bytes()) + defer g.successInfo(fmt.Sprintf("generate unit test file: %s/%s.gen_test.go", g.OutPath, data.FileName)) + return g.output(fmt.Sprintf("%s/%s.gen_test.go", g.OutPath, data.FileName), buf.Bytes()) } // generateModelFile generate model structures and save to file @@ -395,7 +396,7 @@ func (g *Generator) generateModelFile() error { return err } - modelFile := modelOutPath + data.TableName + ".gen.go" + modelFile := modelOutPath + data.FileName + ".gen.go" err = g.output(modelFile, buf.Bytes()) if err != nil { return err diff --git a/internal/check/checkinterface.go b/internal/check/checkinterface.go index deb87cbb..100d1ba7 100644 --- a/internal/check/checkinterface.go +++ b/internal/check/checkinterface.go @@ -11,15 +11,15 @@ import ( // InterfaceMethod interface's method type InterfaceMethod struct { - Doc string //comment - S string //First letter of + Doc string // comment + S string // First letter of OriginStruct parser.Param // origin struct name TargetStruct string // generated query struct bane MethodName string // generated function name Params []parser.Param // function input params Result []parser.Param // function output params ResultData parser.Param // output data - Sections *Sections //Parse split SQL into sections + Sections *Sections // Parse split SQL into sections SqlParams []parser.Param // variable in sql need function input SqlString string // SQL GormOption string // gorm execute method Find or Exec or Take @@ -132,10 +132,10 @@ func (m *InterfaceMethod) checkMethod(methods []*InterfaceMethod, s *BaseStruct) m.InterfaceName, m.MethodName, method.InterfaceName, method.MethodName) } } - for _, member := range s.Members { - if member.Name == m.MethodName { - return fmt.Errorf("can not generate method same name with struct member:[%s.%s] and [%s.%s]", - m.InterfaceName, m.MethodName, s.StructName, member.Name) + for _, f := range s.Fields { + if f.Name == m.MethodName { + return fmt.Errorf("can not generate method same name with struct field:[%s.%s] and [%s.%s]", + m.InterfaceName, m.MethodName, s.StructName, f.Name) } } diff --git a/internal/check/checkstruct.go b/internal/check/checkstruct.go index 701b1ef4..d1074077 100644 --- a/internal/check/checkstruct.go +++ b/internal/check/checkstruct.go @@ -18,12 +18,13 @@ type BaseStruct struct { db *gorm.DB GenBaseStruct bool // whether to generate db model + FileName string // generated file name S string // the first letter(lower case)of simple Name NewStructName string // new struct name StructName string // origin struct name - TableName string + TableName string // table name in db server StructInfo parser.Param - Members []*model.Member + Fields []*model.Field Source model.SourceCode } @@ -35,71 +36,70 @@ func (b *BaseStruct) parseStruct(st interface{}) error { return err } b.TableName = stmt.Table + b.FileName = strings.ToLower(stmt.Table) for _, f := range stmt.Schema.Fields { - b.appendOrUpdateMember(&model.Member{ + b.appendOrUpdateField(&model.Field{ Name: f.Name, - Type: b.getMemberRealType(f.FieldType), + Type: b.getFieldRealType(f.FieldType), ColumnName: f.DBName, }) } for _, r := range ParseStructRelationShip(&stmt.Schema.Relationships) { r := r - b.appendOrUpdateMember(&model.Member{Relation: &r}) + b.appendOrUpdateField(&model.Field{Relation: &r}) } return nil } -// getMemberRealType get basic type of member -func (b *BaseStruct) getMemberRealType(member reflect.Type) string { +// getFieldRealType get basic type of field +func (b *BaseStruct) getFieldRealType(f reflect.Type) string { scanValuer := reflect.TypeOf((*field.ScanValuer)(nil)).Elem() - if member.Implements(scanValuer) || reflect.New(member).Type().Implements(scanValuer) { + if f.Implements(scanValuer) || reflect.New(f).Type().Implements(scanValuer) { return "field" } - if member.Kind() == reflect.Ptr { - member = member.Elem() + if f.Kind() == reflect.Ptr { + f = f.Elem() } - if member.String() == "time.Time" { + if f.String() == "time.Time" { return "time.Time" } - if member.String() == "[]uint8" || member.String() == "json.RawMessage" { + if f.String() == "[]uint8" || f.String() == "json.RawMessage" { return "bytes" } - return member.Kind().String() + return f.Kind().String() } -func (b *BaseStruct) ReviseMemberName() { - for _, m := range b.Members { +func (b *BaseStruct) ReviseFieldName() { + for _, m := range b.Fields { m.EscapeKeyword() } } -// check member if in BaseStruct update else append -func (b *BaseStruct) appendOrUpdateMember(member *model.Member) { - if member.IsRelation() { - b.appendMember(member) +// check field if in BaseStruct update else append +func (b *BaseStruct) appendOrUpdateField(f *model.Field) { + if f.IsRelation() { + b.appendField(f) } - if member.ColumnName == "" { + if f.ColumnName == "" { return } - for index, m := range b.Members { - if m.Name == member.Name { - b.Members[index] = member + for index, m := range b.Fields { + if m.Name == f.Name { + b.Fields[index] = f return } } - b.appendMember(member) + b.appendField(f) } -func (b *BaseStruct) appendMember(member *model.Member) { - b.Members = append(b.Members, member) -} +func (b *BaseStruct) appendField(f *model.Field) { b.Fields = append(b.Fields, f) } -// HasMember check if BaseStruct has members -func (b *BaseStruct) HasMember() bool { return len(b.Members) > 0 } +// HasField check if BaseStruct has fields +func (b *BaseStruct) HasField() bool { return len(b.Fields) > 0 } -// check if struct is exportable and if struct in main package and if member's type is regular +// check if struct is exportable and if struct in main package and if field's type is regular func (b *BaseStruct) check() (err error) { if b.StructInfo.InMainPkg() { return fmt.Errorf("can't generated data object for struct in main package, ignore:%s", b.StructName) @@ -107,15 +107,13 @@ func (b *BaseStruct) check() (err error) { if !isCapitalize(b.StructName) { return fmt.Errorf("can't generated data object for non-exportable struct, ignore:%s", b.NewStructName) } - return nil } -func (b *BaseStruct) Relations() []field.Relation { - result := make([]field.Relation, 0, 4) - for _, m := range b.Members { - if m.IsRelation() { - result = append(result, *m.Relation) +func (b *BaseStruct) Relations() (result []field.Relation) { + for _, f := range b.Fields { + if f.IsRelation() { + result = append(result, *f.Relation) } } return result diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go index b5ef1887..f6016d53 100644 --- a/internal/check/gen_structs.go +++ b/internal/check/gen_structs.go @@ -24,31 +24,49 @@ const ( ) // GenBaseStructs generate db model by table name -func GenBaseStructs(db *gorm.DB, conf model.DBConf) (bases *BaseStruct, err error) { - modelName, tableName := conf.ModelName, conf.TableName +func GenBaseStructs(db *gorm.DB, conf model.Conf) (bases *BaseStruct, err error) { + modelPkg := conf.ModelPkg + tablePrefix := conf.TablePrefix + tableName := conf.TableName + modelName := conf.ModelName if _, ok := db.Config.Dialector.(tests.DummyDialector); ok { return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", modelName, tableName) } + if conf.ModelNameNS != nil { + modelName = conf.ModelNameNS(tableName) + } if err = checkModelName(modelName); err != nil { return nil, fmt.Errorf("model name %q is invalid: %w", modelName, err) } - modelPkg := conf.ModelPkg - if modelPkg == "" { - modelPkg = DefaultModelPkg + if conf.TableNameNS != nil { + tableName = conf.TableNameNS(tableName) + } + if !strings.HasPrefix(tableName, tablePrefix) { + tableName = tablePrefix + tableName } - modelPkg = filepath.Base(modelPkg) - columns, err := getTbColumns(db, conf.GetSchemaName(db), tableName, conf.FieldWithIndexTag) + fileName := strings.ToLower(tableName) + if conf.FileNameNS != nil { + fileName = conf.FileNameNS(conf.TableName) + } + + columns, err := getTblColumns(db, conf.GetSchemaName(db), tableName, conf.FieldWithIndexTag) if err != nil { return nil, err } - base := BaseStruct{ - Source: model.TableName, + if modelPkg == "" { + modelPkg = DefaultModelPkg + } + modelPkg = filepath.Base(modelPkg) + + base := &BaseStruct{ + Source: model.Table, GenBaseStruct: true, + FileName: fileName, TableName: tableName, StructName: modelName, NewStructName: uncaptialize(modelName), @@ -57,20 +75,21 @@ func GenBaseStructs(db *gorm.DB, conf model.DBConf) (bases *BaseStruct, err erro } modifyOpts, filterOpts, createOpts := conf.SortOpt() - for _, field := range columns { - field.SetDataTypeMap(conf.DataTypeMap) - field.WithNS(conf.FieldJSONTagNS, conf.FieldNewTagNS) - m := field.ToMember(conf.FieldNullable) + for _, col := range columns { + col.SetDataTypeMap(conf.DataTypeMap) + col.WithNS(conf.FieldJSONTagNS, conf.FieldNewTagNS) + + m := col.ToField(conf.FieldNullable) - if filterMember(m, filterOpts) == nil { + if filterField(m, filterOpts) == nil { continue } if !conf.FieldWithTypeTag { // remove type tag if FieldWithTypeTag == false - m.GORMTag = strings.ReplaceAll(m.GORMTag, ";type:"+field.ColumnType, "") + m.GORMTag = strings.ReplaceAll(m.GORMTag, ";type:"+col.ColumnType, "") } - m = modifyMember(m, modifyOpts) + m = modifyField(m, modifyOpts) if ns, ok := db.NamingStrategy.(schema.NamingStrategy); ok { ns.SingularTable = true m.Name = ns.SchemaName(ns.TablePrefix + m.Name) @@ -78,11 +97,11 @@ func GenBaseStructs(db *gorm.DB, conf model.DBConf) (bases *BaseStruct, err erro m.Name = db.NamingStrategy.SchemaName(m.Name) } - base.Members = append(base.Members, m) + base.Fields = append(base.Fields, m) } for _, create := range createOpts { - m := create.Self()(nil) + m := create.Operator()(nil) if m.Relation != nil { if m.Relation.Model() != nil { @@ -95,24 +114,24 @@ func GenBaseStructs(db *gorm.DB, conf model.DBConf) (bases *BaseStruct, err erro m.Type = strings.ReplaceAll(m.Type, modelPkg+".", "") // remove modelPkg in field's Type, avoid import error } - base.Members = append(base.Members, m) + base.Fields = append(base.Fields, m) } - return &base, nil + return base, nil } -func filterMember(m *model.Member, opts []model.MemberOpt) *model.Member { +func filterField(m *model.Field, opts []model.FieldOpt) *model.Field { for _, opt := range opts { - if opt.Self()(m) == nil { + if opt.Operator()(m) == nil { return nil } } return m } -func modifyMember(m *model.Member, opts []model.MemberOpt) *model.Member { +func modifyField(m *model.Field, opts []model.FieldOpt) *model.Field { for _, opt := range opts { - m = opt.Self()(m) + m = opt.Operator()(m) } return m } diff --git a/internal/check/tb_info.go b/internal/check/tb_info.go index e3543781..1e009628 100644 --- a/internal/check/tb_info.go +++ b/internal/check/tb_info.go @@ -32,7 +32,7 @@ func getITableInfo(db *gorm.DB) ITableInfo { return &mysqlTableInfo{db: db} } -func getTbColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) { +func getTblColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) { if db == nil { return nil, errors.New("gorm db is nil") } diff --git a/internal/model/base.go b/internal/model/base.go index add789d1..60220264 100644 --- a/internal/model/base.go +++ b/internal/model/base.go @@ -26,7 +26,7 @@ type SourceCode int const ( Struct SourceCode = iota - TableName + Table ) type KeyWords struct { @@ -125,8 +125,8 @@ func (m dataTypeMap) Get(dataType, detailType string) string { return defaultDataType } -// Member user input structures -type Member struct { +// Field user input structures +type Field struct { Name string Type string ColumnName string @@ -140,9 +140,9 @@ type Member struct { Relation *field.Relation } -func (m *Member) IsRelation() bool { return m.Relation != nil } +func (m *Field) IsRelation() bool { return m.Relation != nil } -func (m *Member) GenType() string { +func (m *Field) GenType() string { if m.IsRelation() { return m.Type } @@ -166,7 +166,7 @@ func (m *Member) GenType() string { } } -func (m *Member) EscapeKeyword() *Member { +func (m *Field) EscapeKeyword() *Field { if GormKeywords.FullMatch(m.Name) { m.Name += "_" } diff --git a/internal/model/conf.go b/internal/model/conf.go new file mode 100644 index 00000000..456f0e9c --- /dev/null +++ b/internal/model/conf.go @@ -0,0 +1,53 @@ +package model + +import ( + "gorm.io/gorm" +) + +// FieldConf field configuration +type FieldConf struct { + DataTypeMap map[string]func(detailType string) (dataType string) + + FieldNullable bool // generate pointer when field is nullable + FieldWithIndexTag bool // generate with gorm index tag + FieldWithTypeTag bool // generate with gorm column type tagl + + FieldJSONTagNS func(columnName string) string + FieldNewTagNS func(columnName string) string + + FieldOpts []FieldOpt +} + +// Conf model configuration +type Conf struct { + ModelPkg string + TablePrefix string + TableName string + ModelName string + + SchemaNameOpts []SchemaNameOpt + TableNameNS func(tableName string) string + ModelNameNS func(tableName string) string + FileNameNS func(tableName string) string + + FieldConf +} + +func (cf *Conf) SortOpt() (modifyOpts []FieldOpt, filterOpts []FieldOpt, createOpts []FieldOpt) { + if cf == nil { + return + } + return sortFieldOpt(cf.FieldOpts) +} + +func (cf *Conf) GetSchemaName(db *gorm.DB) string { + if cf == nil { + return defaultSchemaNameOpt(db) + } + for _, opt := range cf.SchemaNameOpts { + if name := opt(db); name != "" { + return name + } + } + return defaultSchemaNameOpt(db) +} diff --git a/internal/model/db_conf.go b/internal/model/db_conf.go deleted file mode 100644 index 1b37bf15..00000000 --- a/internal/model/db_conf.go +++ /dev/null @@ -1,46 +0,0 @@ -package model - -import ( - "gorm.io/gorm" -) - -type DBConf struct { - ModelPkg string - TableName string - ModelName string - - SchemaNameOpts []SchemaNameOpt - MemberOpts []MemberOpt - - DataTypeMap map[string]func(detailType string) (dataType string) - - GenerateModelConfig -} - -type GenerateModelConfig struct { - FieldNullable bool // generate pointer when field is nullable - FieldWithIndexTag bool // generate with gorm index tag - FieldWithTypeTag bool // generate with gorm column type tagl - - FieldJSONTagNS func(columnName string) string - FieldNewTagNS func(columnName string) string -} - -func (cf *DBConf) SortOpt() (modifyOpts []MemberOpt, filterOpts []MemberOpt, createOpts []MemberOpt) { - if cf == nil { - return - } - return sortOpt(cf.MemberOpts) -} - -func (cf *DBConf) GetSchemaName(db *gorm.DB) string { - if cf == nil { - return defaultMysqlSchemaNameOpt(db) - } - for _, opt := range cf.SchemaNameOpts { - if name := opt(db); name != "" { - return name - } - } - return defaultMysqlSchemaNameOpt(db) -} diff --git a/internal/model/options.go b/internal/model/options.go index 83d41919..959bf5c3 100644 --- a/internal/model/options.go +++ b/internal/model/options.go @@ -6,32 +6,32 @@ import ( type SchemaNameOpt func(*gorm.DB) string -var defaultMysqlSchemaNameOpt = SchemaNameOpt(func(db *gorm.DB) string { +var defaultSchemaNameOpt = SchemaNameOpt(func(db *gorm.DB) string { return db.Migrator().CurrentDatabase() }) -type MemberOpt interface{ Self() func(*Member) *Member } +type FieldOpt interface{ Operator() func(*Field) *Field } -type ModifyMemberOpt func(*Member) *Member +type ModifyFieldOpt func(*Field) *Field -func (o ModifyMemberOpt) Self() func(*Member) *Member { return o } +func (o ModifyFieldOpt) Operator() func(*Field) *Field { return o } -type FilterMemberOpt ModifyMemberOpt +type FilterFieldOpt ModifyFieldOpt -func (o FilterMemberOpt) Self() func(*Member) *Member { return o } +func (o FilterFieldOpt) Operator() func(*Field) *Field { return o } -type CreateMemberOpt ModifyMemberOpt +type CreateFieldOpt ModifyFieldOpt -func (o CreateMemberOpt) Self() func(*Member) *Member { return o } +func (o CreateFieldOpt) Operator() func(*Field) *Field { return o } -func sortOpt(opts []MemberOpt) (modifyOpts []MemberOpt, filterOpts []MemberOpt, createOpts []MemberOpt) { +func sortFieldOpt(opts []FieldOpt) (modifyOpts []FieldOpt, filterOpts []FieldOpt, createOpts []FieldOpt) { for _, opt := range opts { switch opt.(type) { - case ModifyMemberOpt: + case ModifyFieldOpt: modifyOpts = append(modifyOpts, opt) - case FilterMemberOpt: + case FilterFieldOpt: filterOpts = append(filterOpts, opt) - case CreateMemberOpt: + case CreateFieldOpt: createOpts = append(createOpts, opt) } } diff --git a/internal/model/tb_column.go b/internal/model/tbl_column.go similarity index 89% rename from internal/model/tb_column.go rename to internal/model/tbl_column.go index 499be6b1..a7cd7251 100644 --- a/internal/model/tb_column.go +++ b/internal/model/tbl_column.go @@ -36,7 +36,7 @@ func (c *Column) SetDataTypeMap(m map[string]func(detailType string) (dataType s c.dataTypeMap = m } -func (c *Column) GetDataType() (memberType string) { +func (c *Column) GetDataType() (fieldtype string) { if mapping, ok := c.dataTypeMap[c.DataType]; ok { return mapping(c.ColumnType) } @@ -53,16 +53,16 @@ func (c *Column) WithNS(jsonTagNS, newTagNS func(columnName string) string) { } } -func (c *Column) ToMember(nullable bool) *Member { - memberType := c.GetDataType() - if c.ColumnName == "deleted_at" && memberType == "time.Time" { - memberType = "gorm.DeletedAt" +func (c *Column) ToField(nullable bool) *Field { + fieldType := c.GetDataType() + if c.ColumnName == "deleted_at" && fieldType == "time.Time" { + fieldType = "gorm.DeletedAt" } else if nullable && c.IsNullable == "YES" { - memberType = "*" + memberType + fieldType = "*" + fieldType } - return &Member{ + return &Field{ Name: c.ColumnName, - Type: memberType, + Type: fieldType, ColumnName: c.ColumnName, ColumnComment: c.ColumnComment, MultilineComment: c.multilineComment(), diff --git a/internal/model/tb_index.go b/internal/model/tbl_index.go similarity index 100% rename from internal/model/tb_index.go rename to internal/model/tbl_index.go diff --git a/internal/template/model.go b/internal/template/model.go index 0a0b9034..1dad902b 100644 --- a/internal/template/model.go +++ b/internal/template/model.go @@ -10,7 +10,7 @@ const TableName{{.StructName}} = "{{.TableName}}" // {{.StructName}} mapped from table <{{.TableName}}> type {{.StructName}} struct { - {{range .Members}} + {{range .Fields}} {{if .MultilineComment -}} /* {{.ColumnComment}} diff --git a/internal/template/struct.go b/internal/template/struct.go index a620d2f4..2e93485a 100644 --- a/internal/template/struct.go +++ b/internal/template/struct.go @@ -4,14 +4,14 @@ const ( BaseStruct = createMethod + ` type {{.NewStructName}} struct { {{.NewStructName}}Do - ` + members + ` + ` + fields + ` } ` + asMethond + getFieldMethod + fillFieldMapMethod + cloneMethod + relationship + defineMethodStruct BaseStructWithContext = createMethod + ` type {{.NewStructName}} struct { {{.NewStructName}}Do {{.NewStructName}}Do - ` + members + ` + ` + fields + ` } ` + asMethond + ` @@ -30,16 +30,16 @@ const ( _{{.NewStructName}}.{{.NewStructName}}Do.UseDB(db) _{{.NewStructName}}.{{.NewStructName}}Do.UseModel(&{{.StructInfo.Package}}.{{.StructInfo.Type}}{}) - {{if .HasMember}}tableName := _{{.NewStructName}}.{{.NewStructName}}Do.TableName(){{end}} + {{if .HasField}}tableName := _{{.NewStructName}}.{{.NewStructName}}Do.TableName(){{end}} _{{$.NewStructName}}.ALL = field.NewField(tableName, "*") - {{range .Members -}} + {{range .Fields -}} {{if not .IsRelation -}} {{- if .ColumnName -}}_{{$.NewStructName}}.{{.Name}} = field.New{{.GenType}}(tableName, "{{.ColumnName}}"){{- end -}} {{- else -}} _{{$.NewStructName}}.{{.Relation.Name}} = {{$.NewStructName}}{{.Relation.RelationshipName}}{{.Relation.Name}}{ db: db.Session(&gorm.Session{}), - {{.Relation.StructMemberInit}} + {{.Relation.StructFieldInit}} } {{end}} {{end}} @@ -49,9 +49,9 @@ const ( return _{{.NewStructName}} } ` - members = ` + fields = ` ALL field.Field - {{range .Members -}} + {{range .Fields -}} {{if not .IsRelation -}} {{- if .ColumnName -}}{{.Name}} field.{{.GenType}}{{- end -}} {{- else -}} @@ -66,7 +66,7 @@ func ({{.S}} {{.NewStructName}}) As(alias string) *{{.NewStructName}} { {{.S}}.{{.NewStructName}}Do.DO = *({{.S}}.{{.NewStructName}}Do.As(alias).(*gen.DO)) {{.S}}.ALL = field.NewField(alias, "*") - {{range .Members -}} + {{range .Fields -}} {{if not .IsRelation -}} {{- if .ColumnName -}}{{$.S}}.{{.Name}} = field.New{{.GenType}}(alias, "{{.ColumnName}}"){{- end -}} {{end}} @@ -94,7 +94,7 @@ func ({{.S}} *{{.NewStructName}}) GetFieldByName(fieldName string) (field.OrderE return _oe,ok } ` - relationship = `{{range .Members}}{{if .IsRelation}}` + + relationship = `{{range .Fields}}{{if .IsRelation}}` + `{{- $relation := .Relation }}{{- $relationship := $relation.RelationshipName}}` + relationStruct + relationTx + `{{end}}{{end}}` @@ -102,8 +102,8 @@ func ({{.S}} *{{.NewStructName}}) GetFieldByName(fieldName string) (field.OrderE fillFieldMapMethod = ` func ({{.S}} *{{.NewStructName}}) fillFieldMap() { - {{.S}}.fieldMap = make(map[string]field.Expr, {{len .Members}}) - {{range .Members -}} + {{.S}}.fieldMap = make(map[string]field.Expr, {{len .Fields}}) + {{range .Fields -}} {{if not .IsRelation -}} {{- if .ColumnName -}}{{$.S}}.fieldMap["{{.ColumnName}}"] = {{$.S}}.{{.Name}}{{- end -}} {{end}} @@ -119,7 +119,7 @@ type {{$.NewStructName}}{{$relationship}}{{$relation.Name}} struct{ field.RelationField - {{$relation.StructMember}} + {{$relation.StructField}} } func (a {{$.NewStructName}}{{$relationship}}{{$relation.Name}}) Where(conds ...field.Expr) *{{$.NewStructName}}{{$relationship}}{{$relation.Name}} {