Skip to content

Commit

Permalink
Refact field struct
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Nov 14, 2013
1 parent e4612bd commit c354b0f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 71 deletions.
8 changes: 4 additions & 4 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -640,8 +640,8 @@ func (s *Do) combinedSql() string {
func (s *Do) createTable() *Do {
var sqls []string
for _, field := range s.model.fields("migration") {
if len(field.SqlType()) > 0 {
sqls = append(sqls, field.DbName+" "+field.SqlType())
if len(field.sqlTag()) > 0 {
sqls = append(sqls, field.DbName+" "+field.sqlTag())
}
}

Expand Down Expand Up @@ -701,8 +701,8 @@ func (s *Do) autoMigrate() *Do {
s.sqlVars = []interface{}{}

// If column doesn't exist
if len(column_name) == 0 && len(field.SqlType()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.SqlType())
if len(column_name) == 0 && len(field.sqlTag()) > 0 {
s.sql = fmt.Sprintf("ALTER TABLE %v ADD %v %v;", s.tableName(), field.DbName, field.sqlTag())
s.exec()
}
}
Expand Down
78 changes: 68 additions & 10 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorm
import (
"database/sql"
"database/sql/driver"

"time"

"strconv"
Expand All @@ -12,22 +13,50 @@ import (
)

type Field struct {
Name string
Value interface{}
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
IsBlank bool
structField reflect.StructField

Name string
Value interface{}
DbName string
AutoCreateTime bool
AutoUpdateTime bool
IsPrimaryKey bool
structField reflect.StructField
modelValue reflect.Value
beforeAssociation bool
afterAssociation bool
foreignKey string
model *Model
}

func (f *Field) SqlType() string {
func (f *Field) isBlank() bool {
value := reflect.ValueOf(f.Value)
switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
return value.Int() == 0
case reflect.String:
return value.String() == ""
case reflect.Slice:
return value.Len() == 0
case reflect.Struct:
time_value, is_time := f.Value.(time.Time)
if is_time {
return time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
if is_scanner {
return !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: f.model.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
return true
}
}
}
}
return false
}

func (f *Field) sqlTag() string {
column := getInterfaceValue(f.Value)
field_value := reflect.ValueOf(f.Value)
switch field_value.Kind() {
Expand Down Expand Up @@ -61,6 +90,35 @@ func (f *Field) SqlType() string {
return typ
}

func (f *Field) parseAssociation() {
field_value := reflect.ValueOf(f.Value)

switch field_value.Kind() {
case reflect.Slice:
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
case reflect.Struct:
_, is_time := f.Value.(time.Time)
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)

if !is_scanner && !is_time {
if f.modelValue.FieldByName(f.Name + "Id").IsValid() {
f.foreignKey = f.Name + "Id"
f.beforeAssociation = true
} else {
foreign_key := f.model.typeName() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
f.foreignKey = foreign_key
}
f.afterAssociation = true
}
}
}
}

func parseSqlTag(str string) (typ string, addational_typ string, size int) {
if str == "-" {
typ = str
Expand Down
64 changes: 7 additions & 57 deletions model.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gorm

import (
"database/sql"
"errors"
"go/ast"
"reflect"
Expand Down Expand Up @@ -73,31 +72,7 @@ func (m *Model) fields(operation string) (fields []*Field) {
value := indirect_value.FieldByName(p.Name)
time_value, is_time := value.Interface().(time.Time)
field.model = m

switch value.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32:
field.IsBlank = value.Int() == 0
case reflect.String:
field.IsBlank = value.String() == ""
case reflect.Slice:
field.IsBlank = value.Len() == 0
case reflect.Struct:
if is_time {
field.IsBlank = time_value.IsZero()
} else {
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)

if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
} else {
m := &Model{data: value.Interface(), do: m.do}
fields := m.columnsHasValue("other")
if len(fields) == 0 {
field.IsBlank = true
}
}
}
}
field.modelValue = indirect_value

if is_time {
field.AutoCreateTime = "created_at" == field.DbName
Expand All @@ -113,37 +88,10 @@ func (m *Model) fields(operation string) (fields []*Field) {
value.Set(reflect.ValueOf(time.Now()))
}
}
} else {
field_value := reflect.Indirect(value)

switch field_value.Kind() {
case reflect.Slice:
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type().Elem()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
case reflect.Struct:
_, is_scanner := reflect.New(field_value.Type()).Interface().(sql.Scanner)

if !is_scanner {
if indirect_value.FieldByName(p.Name + "Id").IsValid() {
field.foreignKey = p.Name + "Id"
field.beforeAssociation = true
} else {
foreign_key := typ.Name() + "Id"
if reflect.New(field_value.Type()).Elem().FieldByName(foreign_key).IsValid() {
field.foreignKey = foreign_key
}
field.afterAssociation = true
}
}
}
}

field.structField = p
field.Value = value.Interface()

fields = append(fields, &field)
}
}
Expand All @@ -157,7 +105,7 @@ func (m *Model) fields(operation string) (fields []*Field) {

func (m *Model) columnsHasValue(operation string) (fields []*Field) {
for _, field := range m.fields(operation) {
if !field.IsBlank {
if !field.isBlank() {
fields = append(fields, field)
}
}
Expand Down Expand Up @@ -199,7 +147,7 @@ func (m *Model) columnsAndValues(operation string) map[string]interface{} {

if m.data != nil {
for _, field := range m.fields(operation) {
if !field.IsPrimaryKey && (len(field.SqlType()) > 0) {
if !field.IsPrimaryKey && (len(field.sqlTag()) > 0) {
results[field.DbName] = field.Value
}
}
Expand Down Expand Up @@ -297,7 +245,8 @@ func (m *Model) setValueByColumn(name string, value interface{}, out interface{}

func (m *Model) beforeAssociations() (fields []*Field) {
for _, field := range m.fields("null") {
if field.beforeAssociation && !field.IsBlank {
field.parseAssociation()
if field.beforeAssociation && !field.isBlank() {
fields = append(fields, field)
}
}
Expand All @@ -306,7 +255,8 @@ func (m *Model) beforeAssociations() (fields []*Field) {

func (m *Model) afterAssociations() (fields []*Field) {
for _, field := range m.fields("null") {
if field.afterAssociation && !field.IsBlank {
field.parseAssociation()
if field.afterAssociation && !field.isBlank() {
fields = append(fields, field)
}
}
Expand Down

0 comments on commit c354b0f

Please sign in to comment.