Skip to content

Commit

Permalink
Replace all use of *sql.DB with sqlCommon
Browse files Browse the repository at this point in the history
Exporting sqlCommon as SQLCommon.

This allows passing alternate implementations of the database connection, or wrapping the connection with middleware.  This change didn't change any usages of the database variables.  All usages were already only using the functions defined in SQLCommon.

This does cause a breaking change in Dialect, since *sql.DB was referenced in the interface.
  • Loading branch information
ansel1 committed Mar 14, 2017
1 parent 5409931 commit 45f1a95
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 20 deletions.
4 changes: 2 additions & 2 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type Dialect interface {
GetName() string

// SetDB set db for dialect
SetDB(db *sql.DB)
SetDB(db SQLCommon)

// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
BindVar(i int) string
Expand Down Expand Up @@ -50,7 +50,7 @@ type Dialect interface {

var dialectsMap = map[string]Dialect{}

func newDialect(name string, db *sql.DB) Dialect {
func newDialect(name string, db SQLCommon) Dialect {
if value, ok := dialectsMap[name]; ok {
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
dialect.SetDB(db)
Expand Down
5 changes: 2 additions & 3 deletions dialect_common.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gorm

import (
"database/sql"
"fmt"
"reflect"
"regexp"
Expand All @@ -15,7 +14,7 @@ type DefaultForeignKeyNamer struct {
}

type commonDialect struct {
db *sql.DB
db SQLCommon
DefaultForeignKeyNamer
}

Expand All @@ -27,7 +26,7 @@ func (commonDialect) GetName() string {
return "common"
}

func (s *commonDialect) SetDB(db *sql.DB) {
func (s *commonDialect) SetDB(db SQLCommon) {
s.db = db
}

Expand Down
5 changes: 2 additions & 3 deletions dialects/mssql/mssql.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package mssql

import (
"database/sql"
"fmt"
"reflect"
"strconv"
Expand All @@ -24,15 +23,15 @@ func init() {
}

type mssql struct {
db *sql.DB
db gorm.SQLCommon
gorm.DefaultForeignKeyNamer
}

func (mssql) GetName() string {
return "mssql"
}

func (s *mssql) SetDB(db *sql.DB) {
func (s *mssql) SetDB(db gorm.SQLCommon) {
s.db = db
}

Expand Down
3 changes: 2 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package gorm

import "database/sql"

type sqlCommon interface {
// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB.
type SQLCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
Expand Down
23 changes: 14 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type DB struct {
RowsAffected int64

// single db
db sqlCommon
db SQLCommon
blockGlobalUpdate bool
logMode int
logger logger
Expand Down Expand Up @@ -47,7 +47,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
return nil, err
}
var source string
var dbSQL *sql.DB
var dbSQL SQLCommon

switch value := args[0].(type) {
case string:
Expand All @@ -59,8 +59,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
case *sql.DB:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
case SQLCommon:
dbSQL = value
}

Expand Down Expand Up @@ -90,21 +89,27 @@ func (s *DB) New() *DB {
return clone
}

// Close close current db connection
type closer interface {
Close() error
}

// Close close current db connection. If database connection is not an io.Closer, returns an error.
func (s *DB) Close() error {
if db, ok := s.parent.db.(*sql.DB); ok {
if db, ok := s.parent.db.(closer); ok {
return db.Close()
}
return errors.New("can't close current db")
}

// DB get `*sql.DB` from current connection
// If the underlying database connection is not a *sql.DB, returns nil
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
db, _ := s.db.(*sql.DB)
return db
}

// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
func (s *DB) CommonDB() SQLCommon {
return s.db
}

Expand Down Expand Up @@ -449,7 +454,7 @@ func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon)
c.db = interface{}(tx).(SQLCommon)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
Expand Down
4 changes: 2 additions & 2 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (scope *Scope) NewDB() *DB {
}

// SQLDB return *sql.DB
func (scope *Scope) SQLDB() sqlCommon {
func (scope *Scope) SQLDB() SQLCommon {
return scope.db.db
}

Expand Down Expand Up @@ -391,7 +391,7 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SQLDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.db.db = interface{}(tx).(SQLCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}
Expand Down

0 comments on commit 45f1a95

Please sign in to comment.