Skip to content

Commit

Permalink
refactor: generate gorm default tag (go-gorm#433)
Browse files Browse the repository at this point in the history
* refactor: generate gorm default tag
  • Loading branch information
tr1v3r authored Apr 24, 2022
1 parent d5a2486 commit 853a371
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion internal/model/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ type dataTypeMapping func(detailType string) (finalType string)
type dataTypeMap map[string]dataTypeMapping

func (m dataTypeMap) Get(dataType, detailType string) string {
if convert, ok := m[dataType]; ok {
if convert, ok := m[strings.ToLower(dataType)]; ok {
return convert(detailType)
}
return defaultDataType
Expand Down
2 changes: 1 addition & 1 deletion internal/model/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type FieldConf struct {
FieldCoverable bool // generate pointer when field has default value
FieldSignable bool // detect integer field's unsigned type, adjust generated data type
FieldWithIndexTag bool // generate with gorm index tag
FieldWithTypeTag bool // generate with gorm column type tagl
FieldWithTypeTag bool // generate with gorm column type tag

FieldJSONTagNS func(columnName string) string
FieldNewTagNS func(columnName string) string
Expand Down
29 changes: 17 additions & 12 deletions internal/model/tbl_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *Column) ToField(nullable, coverable, signable bool) *Field {
if n, ok := c.Nullable(); ok && n {
fieldType = "*" + fieldType
}
case coverable && c.withDefaultValue():
case coverable && c.needDefaultTag(c.defaultTagValue()):
fieldType = "*" + fieldType
}

Expand Down Expand Up @@ -84,7 +84,10 @@ func (c *Column) multilineComment() bool {
func (c *Column) buildGormTag() string {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("column:%s;type:%s", c.Name(), c.columnType()))
if p, ok := c.PrimaryKey(); ok && p {

isPriKey, ok := c.PrimaryKey()
isValidPriKey := ok && isPriKey
if isValidPriKey {
buf.WriteString(";primaryKey")
if at, ok := c.AutoIncrement(); ok {
buf.WriteString(fmt.Sprintf(";autoIncrement:%t", at))
Expand All @@ -103,27 +106,29 @@ func (c *Column) buildGormTag() string {
buf.WriteString(fmt.Sprintf(";index:%s,priority:%d", idx.IndexName, idx.SeqInIndex))
}
}
if c.withDefaultValue() {
buf.WriteString(fmt.Sprintf(";default:%s", c.defaultValue()))

if dtValue := c.defaultTagValue(); !isValidPriKey && c.needDefaultTag(dtValue) { // cannot set default tag for primary key
buf.WriteString(fmt.Sprintf(";default:%s", dtValue))
}
return buf.String()
}

// withDefaultValue check if col has default value and not created_at or updated_at
func (c *Column) withDefaultValue() (normal bool) {
return c.defaultValue() != "" && c.defaultValue() != "0" &&
// needDefaultTag check if default tag needed
func (c *Column) needDefaultTag(defaultTagValue string) bool {
return defaultTagValue != "" && defaultTagValue != "0" &&
c.Name() != "created_at" && c.Name() != "updated_at"
}

func (c *Column) defaultValue() string {
df, ok := c.DefaultValue()
// defaultTagValue return gorm default tag's value
func (c *Column) defaultTagValue() string {
value, ok := c.DefaultValue()
if !ok {
return ""
}
if typ := c.DatabaseTypeName(); strings.Contains(typ, "int") || typ == "numeric" || strings.Contains(typ, "float") {
return df
if strings.Contains(value, " ") {
return "'" + value + "'"
}
return "'" + df + "'"
return value
}

func (c *Column) columnType() (v string) {
Expand Down

0 comments on commit 853a371

Please sign in to comment.