Skip to content

Commit

Permalink
Fix can't update record with customized primary key
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Nov 11, 2014
1 parent a29ac54 commit 010e7a9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 39 deletions.
25 changes: 17 additions & 8 deletions customize_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import (
)

type CustomizeColumn struct {
Id int64 `gorm:"column:mapped_id; primary_key:yes"`
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date time.Time `gorm:"column:mapped_time"`
}

// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
}

func TestCustomizeColumn(t *testing.T) {
Expand All @@ -34,16 +34,25 @@ func TestCustomizeColumn(t *testing.T) {
}

expected := "foo"
cc := CustomizeColumn{Id: 666, Name: expected, Date: time.Now()}
cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}

if count := DB.Save(&cc).RowsAffected; count != 1 {
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}

var ccs []CustomizeColumn
DB.Find(&ccs)
var cc1 CustomizeColumn
DB.First(&cc1, 666)

if len(ccs) > 0 && ccs[0].Name != expected && ccs[0].Id != 666 {
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}

cc.Name = "bar"
DB.Save(&cc)

var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
Expand Down
67 changes: 36 additions & 31 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
)

type Scope struct {
Value interface{}
indirectValue *reflect.Value
Search *search
Sql string
SqlVars []interface{}
db *DB
skipLeft bool
primaryKey string
instanceId string
fields map[string]*Field
Value interface{}
indirectValue *reflect.Value
Search *search
Sql string
SqlVars []interface{}
db *DB
skipLeft bool
primaryKeyField *Field
instanceId string
fields map[string]*Field
}

func (scope *Scope) IndirectValue() reflect.Value {
Expand Down Expand Up @@ -90,27 +90,33 @@ func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}

// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
if scope.primaryKey != "" {
return scope.primaryKey
}

var indirectValue = scope.IndirectValue()
func (scope *Scope) PrimaryKeyField() *Field {
if scope.primaryKeyField == nil {
var indirectValue = scope.IndirectValue()

clone := scope
if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
}
clone := scope
if indirectValue.Kind() == reflect.Slice {
clone = scope.New(reflect.New(indirectValue.Type().Elem()).Elem().Interface())
}

for _, field := range clone.Fields() {
if field.IsPrimaryKey {
scope.primaryKey = field.DBName
break
for _, field := range clone.Fields() {
if field.IsPrimaryKey {
scope.primaryKeyField = field
break
}
}
}

return scope.primaryKey
return scope.primaryKeyField
}

// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryKeyField(); field != nil {
return field.DBName
} else {
return ""
}
}

// PrimaryKeyZero check the primary key is blank or not
Expand All @@ -120,12 +126,11 @@ func (scope *Scope) PrimaryKeyZero() bool {

// PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} {
if scope.IndirectValue().Kind() == reflect.Struct {
if field := scope.IndirectValue().FieldByName(SnakeToUpperCamel(scope.PrimaryKey())); field.IsValid() {
return field.Interface()
}
if field := scope.PrimaryKeyField(); field != nil {
return field.Field.Interface()
} else {
return 0
}
return 0
}

// HasColumn to check if has column
Expand Down

0 comments on commit 010e7a9

Please sign in to comment.