Skip to content

Commit e288006

Browse files
committed
Add TLS-Support
Fixes go-sql-driver#25
1 parent eb5cb81 commit e288006

7 files changed

+159
-61
lines changed

connection.go

+12-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package mysql
1111

1212
import (
13+
"crypto/tls"
1314
"database/sql/driver"
1415
"errors"
1516
"net"
@@ -35,13 +36,15 @@ type mysqlConn struct {
3536
}
3637

3738
type config struct {
38-
user string
39-
passwd string
40-
net string
41-
addr string
42-
dbname string
43-
params map[string]string
44-
loc *time.Location
39+
user string
40+
passwd string
41+
net string
42+
addr string
43+
dbname string
44+
params map[string]string
45+
loc *time.Location
46+
timeout time.Duration
47+
tls *tls.Config
4548
}
4649

4750
// Handles parameters set in DSN
@@ -63,7 +66,7 @@ func (mc *mysqlConn) handleParams() (err error) {
6366
}
6467

6568
// handled elsewhere
66-
case "timeout", "allowAllFiles", "loc", "clientFoundRows":
69+
case "allowAllFiles", "clientFoundRows":
6770
continue
6871

6972
// time.Time parsing
@@ -74,14 +77,10 @@ func (mc *mysqlConn) handleParams() (err error) {
7477
case "strict":
7578
mc.strict = readBool(val)
7679

77-
// TLS-Encryption
78-
case "tls":
79-
err = errors.New("TLS-Encryption not implemented yet")
80-
return
81-
8280
// Compression
8381
case "compress":
8482
err = errors.New("Compression not implemented yet")
83+
return
8584

8685
// System Vars
8786
default:

driver.go

+2-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"database/sql"
1313
"database/sql/driver"
1414
"net"
15-
"time"
1615
)
1716

1817
type mysqlDriver struct{}
@@ -34,11 +33,9 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
3433
}
3534

3635
// Connect to Server
37-
if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
38-
var timeout time.Duration
39-
timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
36+
if mc.cfg.timeout > 0 { // with timeout
4037
if err == nil {
41-
mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
38+
mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout)
4239
}
4340
} else { // no timeout
4441
mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)

driver_test.go

+47-11
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,42 @@ func TestStrict(t *testing.T) {
807807
})
808808
}
809809

810+
func TestTLS(t *testing.T) {
811+
runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
812+
/* TODO: GO 1.1 API */
813+
/*if err := dbt.db.Ping(); err != nil {
814+
if err == errNoTLS {
815+
dbt.Skip("Server does not support TLS. Skipping TestTLS")
816+
} else {
817+
dbt.Fatalf("Error on Ping: %s", err.Error())
818+
}
819+
}*/
820+
821+
/* GO 1.0 API */
822+
if _, err := dbt.db.Exec("DO 1"); err != nil {
823+
if err == errNoTLS {
824+
dbt.Log("Server does not support TLS. Skipping TestTLS")
825+
return
826+
} else {
827+
dbt.Fatalf("Error on Ping: %s", err.Error())
828+
}
829+
}
830+
831+
rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
832+
833+
var variable, value *sql.RawBytes
834+
for rows.Next() {
835+
if err := rows.Scan(&variable, &value); err != nil {
836+
dbt.Fatal(err.Error())
837+
}
838+
839+
if value == nil {
840+
dbt.Fatal("No Cipher")
841+
}
842+
}
843+
})
844+
}
845+
810846
// Special cases
811847

812848
func TestRowsClose(t *testing.T) {
@@ -1040,41 +1076,41 @@ func TestFoundRows(t *testing.T) {
10401076
runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
10411077
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
10421078
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1043-
1079+
10441080
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
10451081
count, err := res.RowsAffected()
10461082
if err != nil {
1047-
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1048-
}
1083+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1084+
}
10491085
if count != 2 {
10501086
dbt.Fatalf("Expected 2 affected rows, got %d", count)
10511087
}
10521088
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
10531089
count, err = res.RowsAffected()
10541090
if err != nil {
1055-
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1056-
}
1091+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1092+
}
10571093
if count != 2 {
10581094
dbt.Fatalf("Expected 2 affected rows, got %d", count)
10591095
}
10601096
})
1061-
runTests(t, "TestFoundRows2", dsn + "&clientFoundRows=true", func(dbt *DBTest) {
1097+
runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
10621098
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
10631099
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
1064-
1100+
10651101
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
10661102
count, err := res.RowsAffected()
10671103
if err != nil {
1068-
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1069-
}
1104+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1105+
}
10701106
if count != 2 {
10711107
dbt.Fatalf("Expected 2 matched rows, got %d", count)
10721108
}
10731109
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
10741110
count, err = res.RowsAffected()
10751111
if err != nil {
1076-
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1077-
}
1112+
dbt.Fatalf("res.RowsAffected() returned error: %v", err)
1113+
}
10781114
if count != 3 {
10791115
dbt.Fatalf("Expected 3 matched rows, got %d", count)
10801116
}

errors.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ import (
1818

1919
var (
2020
errMalformPkt = errors.New("Malformed Packet")
21+
errNoTLS = errors.New("TLS encryption requested but server does not support TLS")
22+
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
23+
errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
2124
errPktSync = errors.New("Commands out of sync. You can't run this command now")
2225
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
23-
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
2426
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
2527
)
2628

packets.go

+49-20
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ package mysql
1111

1212
import (
1313
"bytes"
14+
"crypto/tls"
1415
"database/sql/driver"
1516
"encoding/binary"
16-
"errors"
1717
"fmt"
1818
"io"
1919
"math"
@@ -167,7 +167,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
167167
// capability flags (lower 2 bytes) [2 bytes]
168168
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
169169
if mc.flags&clientProtocol41 == 0 {
170-
err = errors.New("MySQL-Server does not support required Protocol 41+")
170+
err = errOldProtocol
171+
}
172+
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
173+
return errNoTLS
171174
}
172175
pos += 2
173176

@@ -205,19 +208,22 @@ func (mc *mysqlConn) readInitPacket() (err error) {
205208
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
206209
func (mc *mysqlConn) writeAuthPacket() error {
207210
// Adjust client flags based on server support
208-
clientFlags := uint32(
209-
clientProtocol41 |
210-
clientSecureConn |
211-
clientLongPassword |
212-
clientTransactions |
213-
clientLocalFiles,
214-
)
215-
if mc.flags&clientLongFlag > 0 {
216-
clientFlags |= uint32(clientLongFlag)
217-
}
211+
clientFlags := clientProtocol41 |
212+
clientSecureConn |
213+
clientLongPassword |
214+
clientTransactions |
215+
clientLocalFiles |
216+
mc.flags&clientLongFlag
217+
218218
if _, ok := mc.cfg.params["clientFoundRows"]; ok {
219-
clientFlags |= uint32(clientFoundRows)
219+
clientFlags |= clientFoundRows
220220
}
221+
222+
// To enable TLS / SSL
223+
if mc.cfg.tls != nil {
224+
clientFlags |= clientSSL
225+
}
226+
221227
// User Password
222228
scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
223229
mc.cipher = nil
@@ -226,19 +232,13 @@ func (mc *mysqlConn) writeAuthPacket() error {
226232

227233
// To specify a db name
228234
if len(mc.cfg.dbname) > 0 {
229-
clientFlags |= uint32(clientConnectWithDB)
235+
clientFlags |= clientConnectWithDB
230236
pktLen += len(mc.cfg.dbname) + 1
231237
}
232238

233239
// Calculate packet length and make buffer with that size
234240
data := make([]byte, pktLen+4)
235241

236-
// Add the packet header [24bit length + 1 byte sequence]
237-
data[0] = byte(pktLen)
238-
data[1] = byte(pktLen >> 8)
239-
data[2] = byte(pktLen >> 16)
240-
data[3] = mc.sequence
241-
242242
// ClientFlags [32 bit]
243243
data[4] = byte(clientFlags)
244244
data[5] = byte(clientFlags >> 8)
@@ -254,6 +254,35 @@ func (mc *mysqlConn) writeAuthPacket() error {
254254
// Charset [1 byte]
255255
data[12] = mc.charset
256256

257+
// SSL Connection Request Packet
258+
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
259+
if mc.cfg.tls != nil {
260+
// Packet header [24bit length + 1 byte sequence]
261+
data[0] = byte((4 + 4 + 1 + 23))
262+
data[1] = byte((4 + 4 + 1 + 23) >> 8)
263+
data[2] = byte((4 + 4 + 1 + 23) >> 16)
264+
data[3] = mc.sequence
265+
266+
// Send TLS / SSL request packet
267+
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
268+
return err
269+
}
270+
271+
// Switch to TLS
272+
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
273+
if err := tlsConn.Handshake(); err != nil {
274+
return err
275+
}
276+
mc.netConn = tlsConn
277+
mc.buf.rd = tlsConn
278+
}
279+
280+
// Add the packet header [24bit length + 1 byte sequence]
281+
data[0] = byte(pktLen)
282+
data[1] = byte(pktLen >> 8)
283+
data[2] = byte(pktLen >> 16)
284+
data[3] = mc.sequence
285+
257286
// Filler [23 bytes] (all 0x00)
258287
pos := 13 + 23
259288

utils.go

+34-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package mysql
1111

1212
import (
1313
"crypto/sha1"
14+
"crypto/tls"
1415
"database/sql/driver"
1516
"encoding/binary"
1617
"fmt"
@@ -119,7 +120,35 @@ func parseDSN(dsn string) (cfg *config, err error) {
119120
if len(param) != 2 {
120121
continue
121122
}
122-
cfg.params[param[0]] = param[1]
123+
124+
// cfg params
125+
switch value := param[1]; param[0] {
126+
127+
// Time Location
128+
case "loc":
129+
cfg.loc, err = time.LoadLocation(value)
130+
if err != nil {
131+
return
132+
}
133+
134+
// Dial Timeout
135+
case "timeout":
136+
cfg.timeout, err = time.ParseDuration(value)
137+
if err != nil {
138+
return
139+
}
140+
141+
// TLS-Encryption
142+
case "tls":
143+
if readBool(value) {
144+
cfg.tls = &tls.Config{}
145+
} else if strings.ToLower(value) == "skip-verify" {
146+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
147+
}
148+
149+
default:
150+
cfg.params[param[0]] = value
151+
}
123152
}
124153
}
125154
}
@@ -134,7 +163,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
134163
cfg.addr = "127.0.0.1:3306"
135164
}
136165

137-
cfg.loc, err = time.LoadLocation(cfg.params["loc"])
166+
// Set default location if not set
167+
if cfg.loc == nil {
168+
cfg.loc = time.UTC
169+
}
138170

139171
return
140172
}

utils_test.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
2121
out string
2222
loc *time.Location
2323
}{
24-
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
25-
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
26-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
27-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
28-
{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
29-
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
30-
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
31-
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
32-
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
24+
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil>}", time.UTC},
25+
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
26+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
27+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
28+
{"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil>}", time.UTC},
29+
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.Local},
30+
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
31+
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
32+
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
3333
}
3434

3535
var cfg *config
@@ -42,6 +42,9 @@ func TestDSNParser(t *testing.T) {
4242
t.Error(err.Error())
4343
}
4444

45+
// pointer not static
46+
cfg.tls = nil
47+
4548
res = fmt.Sprintf("%+v", cfg)
4649
if res != fmt.Sprintf(tst.out, tst.loc) {
4750
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))

0 commit comments

Comments
 (0)