Skip to content

Commit

Permalink
Refactoring API for plugin system
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 26, 2015
1 parent 087b708 commit ce72988
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 43 deletions.
6 changes: 3 additions & 3 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Create(scope *Scope) {

// execute create sql
if scope.Dialect().SupportLastInsertId() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
Expand All @@ -67,10 +67,10 @@ func Create(scope *Scope) {
}
} else {
if primaryField == nil {
if results, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); err != nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
scope.db.RowsAffected, _ = results.RowsAffected()
}
} else if scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
}
}
Expand Down
2 changes: 1 addition & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Query(scope *Scope) {
scope.prepareQuerySql()

if !scope.HasError() {
rows, err := scope.DB().Query(scope.Sql, scope.SqlVars...)
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0

if scope.Err(err) != nil {
Expand Down
4 changes: 2 additions & 2 deletions common_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand All @@ -102,7 +102,7 @@ func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName str
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand Down
23 changes: 19 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,25 @@ func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}

func (s *DB) New() *DB {
clone := s.clone()
clone.search = nil
return clone
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}

func (s *DB) FreshDB() *DB {
newDB := &DB{
dialect: s.dialect,
logger: s.logger,
callback: s.parent.callback.clone(),
source: s.source,
values: map[string]interface{}{},
db: s.db,
ModelStructs: map[reflect.Type]*ModelStruct{},
}
newDB.parent = newDB
return newDB
}

// CommonDB Return the underlying sql.DB or sql.Tx instance.
Expand Down
4 changes: 2 additions & 2 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (s *mssql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_catalog = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand All @@ -104,7 +104,7 @@ func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand Down
4 changes: 2 additions & 2 deletions mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *mysql) HasTable(scope *Scope, tableName string) bool {
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v AND table_schema = %v",
newScope.AddToVars(tableName),
newScope.AddToVars(s.databaseName(scope))))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand All @@ -102,7 +102,7 @@ func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) boo
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand Down
4 changes: 2 additions & 2 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (s *postgres) HasTable(scope *Scope, tableName string) bool {
var count int
newScope := scope.New(nil)
newScope.Raw(fmt.Sprintf("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = %v and table_type = 'BASE TABLE'", newScope.AddToVars(tableName)))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand All @@ -94,7 +94,7 @@ func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string)
newScope.AddToVars(tableName),
newScope.AddToVars(columnName),
))
newScope.DB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
newScope.SqlDB().QueryRow(newScope.Sql, newScope.SqlVars...).Scan(&count)
return count > 0
}

Expand Down
38 changes: 15 additions & 23 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ func (scope *Scope) IndirectValue() reflect.Value {
return *scope.indirectValue
}

// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search, Value: value}
}

func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
Expand All @@ -52,16 +45,21 @@ func (scope *Scope) NeedPtr() *Scope {

// New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.db, Search: &search{}, Value: value}
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
}

// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
return scope.db.New()
if scope.db != nil {
db := scope.db.clone()
db.search = nil
return db
}
return nil
}

// DB get *sql.DB
func (scope *Scope) DB() sqlCommon {
// SqlDB return *sql.DB
func (scope *Scope) SqlDB() sqlCommon {
return scope.db.db
}

Expand All @@ -73,9 +71,8 @@ func (scope *Scope) SkipLeft() {
// Quote used to quote database column name according to database dialect
func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 {
strs := strings.Split(str, ".")
newStrs := []string{}
for _, str := range strs {
for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str))
}
return strings.Join(newStrs, ".")
Expand Down Expand Up @@ -176,13 +173,13 @@ func (scope *Scope) CallMethod(name string, checkError bool) {
case func(s *Scope):
f(scope)
case func(s *DB):
f(scope.db.New())
f(scope.NewDB())
case func() error:
scope.Err(f())
case func(s *Scope) error:
scope.Err(f(scope))
case func(s *DB) error:
scope.Err(f(scope.db.New()))
scope.Err(f(scope.NewDB()))
default:
scope.Err(fmt.Errorf("unsupported function %v", name))
}
Expand Down Expand Up @@ -229,12 +226,7 @@ func (scope *Scope) QuotedTableName() string {
return scope.Search.TableName
}

keys := strings.Split(scope.TableName(), ".")
for i, v := range keys {
keys[i] = scope.Quote(v)
}
return strings.Join(keys, ".")

return scope.Quote(scope.TableName())
}

// CombinedConditionSql get combined condition sql
Expand Down Expand Up @@ -263,7 +255,7 @@ func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())

if !scope.HasError() {
if result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil {
scope.db.RowsAffected = count
}
Expand Down Expand Up @@ -308,7 +300,7 @@ func (scope *Scope) Trace(t time.Time) {

// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.DB().(sqlDb); ok {
if db, ok := scope.SqlDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true)
Expand Down
4 changes: 2 additions & 2 deletions scope_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc())
scope.prepareQuerySql()
return scope.DB().QueryRow(scope.Sql, scope.SqlVars...)
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
}

func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc())
scope.prepareQuerySql()
return scope.DB().Query(scope.Sql, scope.SqlVars...)
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
}

func (scope *Scope) initialize() *Scope {
Expand Down
4 changes: 2 additions & 2 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ func (s *sqlite3) Quote(key string) string {

func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='%v';", tableName)).Scan(&count)
return count > 0
}

func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.DB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
scope.SqlDB().QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = '%v' AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", tableName, columnName, columnName, columnName, columnName)).Scan(&count)
return count > 0
}

Expand Down

0 comments on commit ce72988

Please sign in to comment.