diff --git a/benchmark_test.go b/benchmark_test.go index 7ccb46fcc..8f721139b 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -49,9 +49,9 @@ func initDB(b *testing.B, queries ...string) *sql.DB { for _, query := range queries { if _, err := db.Exec(query); err != nil { if w, ok := err.(MySQLWarnings); ok { - b.Logf("Warning on %q: %v", query, w) + b.Logf("warning on %q: %v", query, w) } else { - b.Fatalf("Error on %q: %v", query, err) + b.Fatalf("error on %q: %v", query, err) } } } diff --git a/driver.go b/driver.go index 965b663b7..899f955fb 100644 --- a/driver.go +++ b/driver.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// Package mysql provides a MySQL driver for Go's database/sql package // // The driver should be used via the database/sql package: // @@ -22,7 +22,7 @@ import ( "net" ) -// This struct is exported to make the driver directly accessible. +// MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} diff --git a/driver_test.go b/driver_test.go index 0e9571a59..319d34f91 100644 --- a/driver_test.go +++ b/driver_test.go @@ -78,12 +78,12 @@ type DBTest struct { func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } db, err := sql.Open("mysql", dsn) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } defer db.Close() @@ -94,7 +94,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open("mysql", dsn2) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } defer db2.Close() } @@ -115,13 +115,13 @@ func (dbt *DBTest) fail(method, query string, err error) { if len(query) > 300 { query = "[query too large to print]" } - dbt.Fatalf("Error on %s %s: %s", method, query, err.Error()) + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { res, err := dbt.db.Exec(query, args...) if err != nil { - dbt.fail("Exec", query, err) + dbt.fail("exec", query, err) } return res } @@ -129,7 +129,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { rows, err := dbt.db.Query(query, args...) if err != nil { - dbt.fail("Query", query, err) + dbt.fail("query", query, err) } return rows } @@ -140,7 +140,7 @@ func TestEmptyQuery(t *testing.T) { rows := dbt.mustQuery("--") // will hang before #255 if rows.Next() { - dbt.Errorf("Next on rows must be false") + dbt.Errorf("next on rows must be false") } }) } @@ -164,7 +164,7 @@ func TestCRUD(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { - dbt.Fatalf("Expected 1 affected row, got %d", count) + dbt.Fatalf("expected 1 affected row, got %d", count) } id, err := res.LastInsertId() @@ -172,7 +172,7 @@ func TestCRUD(t *testing.T) { dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) } if id != 0 { - dbt.Fatalf("Expected InsertID 0, got %d", id) + dbt.Fatalf("expected InsertId 0, got %d", id) } // Read @@ -197,7 +197,7 @@ func TestCRUD(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { - dbt.Fatalf("Expected 1 affected row, got %d", count) + dbt.Fatalf("expected 1 affected row, got %d", count) } // Check Update @@ -222,7 +222,7 @@ func TestCRUD(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 1 { - dbt.Fatalf("Expected 1 affected row, got %d", count) + dbt.Fatalf("expected 1 affected row, got %d", count) } // Check for unexpected rows @@ -232,7 +232,7 @@ func TestCRUD(t *testing.T) { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) } if count != 0 { - dbt.Fatalf("Expected 0 affected row, got %d", count) + dbt.Fatalf("expected 0 affected row, got %d", count) } }) } @@ -653,14 +653,14 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if nb.Valid { - dbt.Error("Valid NullBool which should be invalid") + dbt.Error("valid NullBool which should be invalid") } // Valid if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { dbt.Fatal(err) } if !nb.Valid { - dbt.Error("Invalid NullBool which should be valid") + dbt.Error("invalid NullBool which should be valid") } else if nb.Bool != true { dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) } @@ -672,16 +672,16 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if nf.Valid { - dbt.Error("Valid NullFloat64 which should be invalid") + dbt.Error("valid NullFloat64 which should be invalid") } // Valid if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { dbt.Fatal(err) } if !nf.Valid { - dbt.Error("Invalid NullFloat64 which should be valid") + dbt.Error("invalid NullFloat64 which should be valid") } else if nf.Float64 != float64(1) { - dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) } // NullInt64 @@ -691,16 +691,16 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if ni.Valid { - dbt.Error("Valid NullInt64 which should be invalid") + dbt.Error("valid NullInt64 which should be invalid") } // Valid if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { dbt.Fatal(err) } if !ni.Valid { - dbt.Error("Invalid NullInt64 which should be valid") + dbt.Error("invalid NullInt64 which should be valid") } else if ni.Int64 != int64(1) { - dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64) + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) } // NullString @@ -710,16 +710,16 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if ns.Valid { - dbt.Error("Valid NullString which should be invalid") + dbt.Error("valid NullString which should be invalid") } // Valid if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { dbt.Fatal(err) } if !ns.Valid { - dbt.Error("Invalid NullString which should be valid") + dbt.Error("invalid NullString which should be valid") } else if ns.String != `1` { - dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)") + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") } // nil-bytes @@ -729,14 +729,14 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if b != nil { - dbt.Error("Non-nil []byte wich should be nil") + dbt.Error("non-nil []byte wich should be nil") } // Read non-nil if err = nonNullStmt.QueryRow().Scan(&b); err != nil { dbt.Fatal(err) } if b == nil { - dbt.Error("Nil []byte wich should be non-nil") + dbt.Error("nil []byte wich should be non-nil") } // Insert nil b = nil @@ -745,7 +745,7 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if !success { - dbt.Error("Inserting []byte(nil) as NULL failed") + dbt.Error("inserting []byte(nil) as NULL failed") } // Check input==output with input==nil b = nil @@ -753,7 +753,7 @@ func TestNULL(t *testing.T) { dbt.Fatal(err) } if b != nil { - dbt.Error("Non-nil echo from nil input") + dbt.Error("non-nil echo from nil input") } // Check input==output with input!=nil b = []byte("") @@ -820,7 +820,7 @@ func TestUint64(t *testing.T) { sb != shigh, sc != stop, sd != sall: - dbt.Fatal("Unexpected result value") + dbt.Fatal("unexpected result value") } }) } @@ -924,7 +924,7 @@ func TestLoadData(t *testing.T) { } if i != 4 { - dbt.Fatalf("Rows count mismatch. Got %d, want 4", i) + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) } } file, err := ioutil.TempFile("", "gotest") @@ -945,8 +945,8 @@ func TestLoadData(t *testing.T) { // negative test _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") if err == nil { - dbt.Fatal("Load non-existent file didn't fail") - } else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { dbt.Fatal(err.Error()) } @@ -966,7 +966,7 @@ func TestLoadData(t *testing.T) { // negative test _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") if err == nil { - dbt.Fatal("Load non-existent Reader didn't fail") + dbt.Fatal("load non-existent Reader didn't fail") } else if err.Error() != "Reader 'doesnotexist' is not registered" { dbt.Fatal(err.Error()) } @@ -1046,7 +1046,7 @@ func TestStrict(t *testing.T) { var checkWarnings = func(err error, mode string, idx int) { if err == nil { - dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in) + dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in) } if warnings, ok := err.(MySQLWarnings); ok { @@ -1055,18 +1055,18 @@ func TestStrict(t *testing.T) { codes[i] = warnings[i].Code } if len(codes) != len(queries[idx].codes) { - dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) } for i := range warnings { if codes[i] != queries[idx].codes[i] { - dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) + dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes) return } } } else { - dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) + dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error()) } } @@ -1082,7 +1082,7 @@ func TestStrict(t *testing.T) { for i := range queries { stmt, err = dbt.db.Prepare(queries[i].in) if err != nil { - dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error()) + dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error()) } _, err = stmt.Exec() @@ -1090,7 +1090,7 @@ func TestStrict(t *testing.T) { err = stmt.Close() if err != nil { - dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error()) + dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error()) } } }) @@ -1100,9 +1100,9 @@ func TestTLS(t *testing.T) { tlsTest := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { if err == ErrNoTLS { - dbt.Skip("Server does not support TLS") + dbt.Skip("server does not support TLS") } else { - dbt.Fatalf("Error on Ping: %s", err.Error()) + dbt.Fatalf("error on Ping: %s", err.Error()) } } @@ -1115,7 +1115,7 @@ func TestTLS(t *testing.T) { } if value == nil { - dbt.Fatal("No Cipher") + dbt.Fatal("no Cipher") } } } @@ -1132,42 +1132,42 @@ func TestTLS(t *testing.T) { func TestReuseClosedConnection(t *testing.T) { // this test does not use sql.database, it uses the driver directly if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } md := &MySQLDriver{} conn, err := md.Open(dsn) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } stmt, err := conn.Prepare("DO 1") if err != nil { - t.Fatalf("Error preparing statement: %s", err.Error()) + t.Fatalf("error preparing statement: %s", err.Error()) } _, err = stmt.Exec(nil) if err != nil { - t.Fatalf("Error executing statement: %s", err.Error()) + t.Fatalf("error executing statement: %s", err.Error()) } err = conn.Close() if err != nil { - t.Fatalf("Error closing connection: %s", err.Error()) + t.Fatalf("error closing connection: %s", err.Error()) } defer func() { if err := recover(); err != nil { - t.Errorf("Panic after reusing a closed connection: %v", err) + t.Errorf("panic after reusing a closed connection: %v", err) } }() _, err = stmt.Exec(nil) if err != nil && err != driver.ErrBadConn { - t.Errorf("Unexpected error '%s', expected '%s'", + t.Errorf("unexpected error '%s', expected '%s'", err.Error(), driver.ErrBadConn.Error()) } } func TestCharset(t *testing.T) { if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } mustSetCharset := func(charsetParam, expected string) { @@ -1176,14 +1176,14 @@ func TestCharset(t *testing.T) { defer rows.Close() if !rows.Next() { - dbt.Fatalf("Error getting connection charset: %s", rows.Err()) + dbt.Fatalf("error getting connection charset: %s", rows.Err()) } var got string rows.Scan(&got) if got != expected { - dbt.Fatalf("Expected connection charset %s but got %s", expected, got) + dbt.Fatalf("expected connection charset %s but got %s", expected, got) } }) } @@ -1205,14 +1205,14 @@ func TestFailingCharset(t *testing.T) { _, err := dbt.db.Exec("SELECT 1") if err == nil { dbt.db.Close() - t.Fatalf("Connection must not succeed without a valid charset") + t.Fatalf("connection must not succeed without a valid charset") } }) } func TestCollation(t *testing.T) { if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } defaultCollation := "utf8_general_ci" @@ -1242,7 +1242,7 @@ func TestCollation(t *testing.T) { } if got != expected { - dbt.Fatalf("Expected connection collation %s but got %s", expected, got) + dbt.Fatalf("expected connection collation %s but got %s", expected, got) } }) } @@ -1307,7 +1307,7 @@ func TestTimezoneConversion(t *testing.T) { // Retrieve time from DB rows := dbt.mustQuery("SELECT ts FROM test") if !rows.Next() { - dbt.Fatal("Didn't get any rows out") + dbt.Fatal("did not get any rows out") } var dbTime time.Time @@ -1318,7 +1318,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { - dbt.Errorf("Times don't match.\n") + dbt.Errorf("times do not match.\n") dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } @@ -1344,7 +1344,7 @@ func TestRowsClose(t *testing.T) { } if rows.Next() { - dbt.Fatal("Unexpected row after rows.Close()") + dbt.Fatal("unexpected row after rows.Close()") } err = rows.Err() @@ -1376,7 +1376,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { } if !rows.Next() { - dbt.Fatal("Getting row failed") + dbt.Fatal("getting row failed") } else { err = rows.Err() if err != nil { @@ -1386,7 +1386,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { var out bool err = rows.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -1422,7 +1422,7 @@ func TestStmtMultiRows(t *testing.T) { // 1 if !rows1.Next() { - dbt.Fatal("1st rows1.Next failed") + dbt.Fatal("first rows1.Next failed") } else { err = rows1.Err() if err != nil { @@ -1431,7 +1431,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows1.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -1439,7 +1439,7 @@ func TestStmtMultiRows(t *testing.T) { } if !rows2.Next() { - dbt.Fatal("1st rows2.Next failed") + dbt.Fatal("first rows2.Next failed") } else { err = rows2.Err() if err != nil { @@ -1448,7 +1448,7 @@ func TestStmtMultiRows(t *testing.T) { err = rows2.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) } if out != true { dbt.Errorf("true != %t", out) @@ -1457,7 +1457,7 @@ func TestStmtMultiRows(t *testing.T) { // 2 if !rows1.Next() { - dbt.Fatal("2nd rows1.Next failed") + dbt.Fatal("second rows1.Next failed") } else { err = rows1.Err() if err != nil { @@ -1466,14 +1466,14 @@ func TestStmtMultiRows(t *testing.T) { err = rows1.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) } if out != false { dbt.Errorf("false != %t", out) } if rows1.Next() { - dbt.Fatal("Unexpected row on rows1") + dbt.Fatal("unexpected row on rows1") } err = rows1.Close() if err != nil { @@ -1482,7 +1482,7 @@ func TestStmtMultiRows(t *testing.T) { } if !rows2.Next() { - dbt.Fatal("2nd rows2.Next failed") + dbt.Fatal("second rows2.Next failed") } else { err = rows2.Err() if err != nil { @@ -1491,14 +1491,14 @@ func TestStmtMultiRows(t *testing.T) { err = rows2.Scan(&out) if err != nil { - dbt.Fatalf("Error on rows.Scan(): %s", err.Error()) + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) } if out != false { dbt.Errorf("false != %t", out) } if rows2.Next() { - dbt.Fatal("Unexpected row on rows2") + dbt.Fatal("unexpected row on rows2") } err = rows2.Close() if err != nil { @@ -1543,7 +1543,7 @@ func TestConcurrent(t *testing.T) { if err != nil { dbt.Fatalf("%s", err.Error()) } - dbt.Logf("Testing up to %d concurrent connections \r\n", max) + dbt.Logf("testing up to %d concurrent connections \r\n", max) var remaining, succeeded int32 = int32(max), 0 @@ -1567,7 +1567,7 @@ func TestConcurrent(t *testing.T) { if err != nil { if err.Error() != "Error 1040: Too many connections" { - fatalf("Error on Conn %d: %s", id, err.Error()) + fatalf("error on conn %d: %s", id, err.Error()) } return } @@ -1575,13 +1575,13 @@ func TestConcurrent(t *testing.T) { // keep the connection busy until all connections are open for remaining > 0 { if _, err = tx.Exec("DO 1"); err != nil { - fatalf("Error on Conn %d: %s", id, err.Error()) + fatalf("error on conn %d: %s", id, err.Error()) return } } if err = tx.Commit(); err != nil { - fatalf("Error on Conn %d: %s", id, err.Error()) + fatalf("error on conn %d: %s", id, err.Error()) return } @@ -1597,14 +1597,14 @@ func TestConcurrent(t *testing.T) { dbt.Fatal(fatalError) } - dbt.Logf("Reached %d concurrent connections\r\n", succeeded) + dbt.Logf("reached %d concurrent connections\r\n", succeeded) }) } // Tests custom dial functions func TestCustomDial(t *testing.T) { if !available { - t.Skipf("MySQL-Server not running on %s", netAddr) + t.Skipf("MySQL server not running on %s", netAddr) } // our custom dial function which justs wraps net.Dial here @@ -1614,16 +1614,16 @@ func TestCustomDial(t *testing.T) { db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname)) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } defer db.Close() if _, err = db.Exec("DO 1"); err != nil { - t.Fatalf("Connection failed: %s", err.Error()) + t.Fatalf("connection failed: %s", err.Error()) } } -func TestSqlInjection(t *testing.T) { +func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") @@ -1636,9 +1636,9 @@ func TestSqlInjection(t *testing.T) { if err == sql.ErrNoRows { return // success, sql injection failed } else if err == nil { - dbt.Errorf("Sql injection successful with arg: %s", arg) + dbt.Errorf("sql injection successful with arg: %s", arg) } else { - dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error()) + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) } } } @@ -1705,14 +1705,14 @@ func TestUnixSocketAuthFail(t *testing.T) { // Get socket file from MySQL. err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) if err != nil { - t.Fatalf("Error on SELECT @@socket: %s", err.Error()) + t.Fatalf("error on SELECT @@socket: %s", err.Error()) } } t.Logf("socket: %s", socket) badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname) db, err := sql.Open("mysql", badDSN) if err != nil { - t.Fatalf("Error connecting: %s", err.Error()) + t.Fatalf("error connecting: %s", err.Error()) } defer db.Close() diff --git a/dsn.go b/dsn.go index 8ca73dc68..31fd530ee 100644 --- a/dsn.go +++ b/dsn.go @@ -19,10 +19,10 @@ import ( ) var ( - errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") - errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") - errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") - errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset") + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) // Config is a configuration parsed from a DSN string @@ -141,7 +141,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { case "unix": cfg.Addr = "/tmp/mysql.sock" default: - return nil, errors.New("Default addr for network '" + cfg.Net + "' unknown") + return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") } } @@ -166,7 +166,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.AllowAllFiles, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Use cleartext authentication mode (MySQL 5.5.10+) @@ -174,7 +174,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.AllowCleartextPasswords, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Use old authentication mode (pre MySQL 4.1) @@ -182,7 +182,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.AllowOldPasswords, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Switch "rowsAffected" mode @@ -190,7 +190,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.ClientFoundRows, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Collation @@ -210,19 +210,19 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.ColumnsWithAlias, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Compression case "compress": - return errors.New("Compression not implemented yet") + return errors.New("compression not implemented yet") // Enable client side placeholder substitution case "interpolateParams": var isBool bool cfg.InterpolateParams, isBool = readBool(value) if !isBool { - return fmt.Errorf("Invalid Bool value: %s", value) + return errors.New("invalid bool value: " + value) } // Time Location @@ -240,7 +240,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.ParseTime, isBool = readBool(value) if !isBool { - return errors.New("Invalid Bool value: " + value) + return errors.New("invalid bool value: " + value) } // I/O read Timeout @@ -255,7 +255,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { var isBool bool cfg.Strict, isBool = readBool(value) if !isBool { - return errors.New("Invalid Bool value: " + value) + return errors.New("invalid bool value: " + value) } // Dial Timeout @@ -273,7 +273,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { cfg.TLS = &tls.Config{} } } else if value, err := url.QueryUnescape(value); err != nil { - return fmt.Errorf("Invalid value for tls config name: %v", err) + return fmt.Errorf("invalid value for TLS config name: %v", err) } else { if strings.ToLower(value) == "skip-verify" { cfg.TLS = &tls.Config{InsecureSkipVerify: true} @@ -287,7 +287,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { cfg.TLS = tlsConfig } else { - return fmt.Errorf("Invalid value / unknown config name: %s", value) + return errors.New("invalid value / unknown config name: " + value) } } diff --git a/errors.go b/errors.go index 44cf30db6..1543a8054 100644 --- a/errors.go +++ b/errors.go @@ -19,20 +19,20 @@ import ( // Various errors the driver might return. Can change between driver versions. var ( - ErrInvalidConn = errors.New("Invalid Connection") - ErrMalformPkt = errors.New("Malformed Packet") - ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS") - ErrOldPassword = errors.New("This user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") - ErrCleartextPassword = errors.New("This user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN.") - ErrUnknownPlugin = errors.New("The authentication plugin is not supported.") - ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+") - ErrPktSync = errors.New("Commands out of sync. You can't run this command now") - ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?") - ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.") - ErrBusyBuffer = errors.New("Busy buffer") + ErrInvalidConn = errors.New("invalid connection") + ErrMalformPkt = errors.New("malformed packet") + ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") + ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrUnknownPlugin = errors.New("this authentication plugin is not supported") + ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") + ErrPktSync = errors.New("commands out of sync. You can't run this command now") + ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrBusyBuffer = errors.New("busy buffer") ) -var errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) +var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) // Logger is used to log critical error messages. type Logger interface { diff --git a/infile.go b/infile.go index 9c898b705..e3e5e47d9 100644 --- a/infile.go +++ b/infile.go @@ -139,12 +139,12 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } else if fileSize <= mc.maxPacketAllowed { data = make([]byte, 4+mc.maxWriteSize) } else { - err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed) + err = fmt.Errorf("local file '%s' too large: size: %d, max: %d", name, fileSize, mc.maxPacketAllowed) } } } } else { - err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name) + err = fmt.Errorf("local file '%s' is not registered", name) } } @@ -175,8 +175,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { return mc.readResultOK() - } else { - mc.readPacket() } + + mc.readPacket() return err } diff --git a/packets.go b/packets.go index 88d990a2d..6ac1cccea 100644 --- a/packets.go +++ b/packets.go @@ -47,9 +47,8 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul - } else { - return nil, ErrPktSync } + return nil, ErrPktSync } mc.sequence++ @@ -146,7 +145,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // protocol version [1 byte] if data[0] < minProtocolVersion { return nil, fmt.Errorf( - "Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required", + "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, ) @@ -539,13 +538,13 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // warning count [2 bytes] if !mc.strict { return nil - } else { - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } - return nil } + + pos := 1 + n + m + 2 + if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { + return mc.getWarnings() + } + return nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -564,7 +563,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if i == count { return columns, nil } - return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns)) + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) } // Catalog @@ -742,13 +741,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Warning count [16 bit uint] if !stmt.mc.strict { return columnCount, nil - } else { - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } - return columnCount, nil } + + // Check for warnings count > 0, only available in MySQL > 4.1 + if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { + return columnCount, stmt.mc.getWarnings() + } + return columnCount, nil } return 0, err } @@ -810,7 +809,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( - "Arguments count mismatch (Got: %d Has: %d)", + "argument count mismatch (got: %d; has: %d)", len(args), stmt.paramCount, ) @@ -996,7 +995,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramValues = append(paramValues, val...) default: - return fmt.Errorf("Can't convert type: %T", arg) + return fmt.Errorf("can not convert type: %T", arg) } } @@ -1144,7 +1143,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dstlen = 8 + 1 + decimals default: return fmt.Errorf( - "MySQL protocol error, illegal decimals value %d", + "protocol error, illegal decimals value %d", rows.columns[i].decimals, ) } @@ -1163,7 +1162,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dstlen = 19 + 1 + decimals default: return fmt.Errorf( - "MySQL protocol error, illegal decimals value %d", + "protocol error, illegal decimals value %d", rows.columns[i].decimals, ) } @@ -1180,7 +1179,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) } } diff --git a/utils.go b/utils.go index e267cce4e..d523b7ffd 100644 --- a/utils.go +++ b/utils.go @@ -23,10 +23,6 @@ var ( tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs ) -func init() { - tlsConfigRegister = make(map[string]*tls.Config) -} - // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // @@ -52,7 +48,11 @@ func init() { // func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { - return fmt.Errorf("Key '%s' is reserved", key) + return fmt.Errorf("key '%s' is reserved", key) + } + + if tlsConfigRegister == nil { + tlsConfigRegister = make(map[string]*tls.Config) } tlsConfigRegister[key] = config @@ -61,7 +61,9 @@ func RegisterTLSConfig(key string, config *tls.Config) error { // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { - delete(tlsConfigRegister, key) + if tlsConfigRegister != nil { + delete(tlsConfigRegister, key) + } } // Returns the bool value of the input. @@ -258,7 +260,7 @@ func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { } t, err = time.Parse(timeFormat[:len(str)], str) default: - err = fmt.Errorf("Invalid Time-String: %s", str) + err = fmt.Errorf("invalid time string: %s", str) return } @@ -307,7 +309,7 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va loc, ), nil } - return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) + return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } // zeroDateTime is used in formatBinaryDateTime to avoid an allocation @@ -342,7 +344,7 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value switch len(src) { case 8, 12: default: - return nil, fmt.Errorf("Invalid TIME-packet length %d", len(src)) + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) } // +2 to enable negative time and 100+ hours dst = make([]byte, 0, length+2) @@ -376,7 +378,7 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value if length > 10 { t += "TIME" } - return nil, fmt.Errorf("illegal %s-packet length %d", t, len(src)) + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) } dst = make([]byte, 0, length) // start with the date @@ -642,7 +644,7 @@ func escapeBytesBackslash(buf, v []byte) []byte { pos += 2 default: buf[pos] = c - pos += 1 + pos++ } } @@ -687,7 +689,7 @@ func escapeStringBackslash(buf []byte, v string) []byte { pos += 2 default: buf[pos] = c - pos += 1 + pos++ } }