Skip to content

Commit

Permalink
Set table name handler
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed May 27, 2015
1 parent cbebcf6 commit b96ca76
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
4 changes: 0 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
}
}
}

func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) {
s.NewScope(source).GetModelStruct().TableName = handler
}
36 changes: 21 additions & 15 deletions model_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@ import (

var modelStructs = map[reflect.Type]*ModelStruct{}

var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}

type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
TableName func(*DB) string
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
defaultTableName string
}

func (s ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}

type StructField struct {
Expand Down Expand Up @@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}

// Set tablename
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok {
modelStruct.TableName = func(*DB) string {
return name
}
}
}
type tabler interface {
TableName() string
}

if tabler, ok := reflect.New(scopeType).Interface().(interface {
TableName() string
}); ok {
modelStruct.defaultTableName = tabler.TableName()
} else {
name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable {
Expand All @@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}

modelStruct.TableName = func(*DB) string {
return name
}
modelStruct.defaultTableName = name
}

// Get all fields
Expand Down
7 changes: 1 addition & 6 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,7 @@ func (scope *Scope) TableName() string {
return tabler.TableName(scope.db)
}

if scope.GetModelStruct().TableName != nil {
return scope.GetModelStruct().TableName(scope.db)
}

scope.Err(errors.New("wrong table name"))
return ""
return scope.GetModelStruct().TableName(scope.db)
}

func (scope *Scope) QuotedTableName() (name string) {
Expand Down

0 comments on commit b96ca76

Please sign in to comment.