Skip to content

Commit

Permalink
Close cached prepared stmt when got error
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jul 3, 2020
1 parent 8100ac7 commit f93345a
Showing 1 changed file with 36 additions and 42 deletions.
78 changes: 36 additions & 42 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,38 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
return nil, ErrInvalidTransaction
}

func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.ExecContext(ctx, args...)
} else {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
}
}
return nil, err
return result, err
}

func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryContext(ctx, args...)
} else {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
}
}
return nil, err
return rows, err
}

func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
} else {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
}
return &sql.Row{}
}
Expand All @@ -98,41 +95,38 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB
}

func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
} else {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
}
}
return nil, err
return result, err
}

func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
} else {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
}
}
return nil, err
return rows, err
}

func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
} else {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
}
return &sql.Row{}
}

0 comments on commit f93345a

Please sign in to comment.