Skip to content

Commit

Permalink
Fix get default value from blank primary field
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jun 2, 2016
1 parent dca5e54 commit bf0e236
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,12 @@ func createCallback(scope *Scope) {
for _, field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || !field.IsBlank {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns, scope.Quote(field.DBName))
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _, foreignKey := range field.Relationship.ForeignDBNames {
Expand Down Expand Up @@ -129,7 +127,13 @@ func createCallback(scope *Scope) {
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value)
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
for _, field := range scope.Fields() {
if field.IsPrimaryKey && !field.IsBlank {
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
}
}
db.Scan(scope.Value)
}
}

Expand Down

0 comments on commit bf0e236

Please sign in to comment.