Skip to content

Commit 78d399c

Browse files
pushraxjulienschmidt
authored andcommitted
Implement NamedValueChecker for mysqlConn (go-sql-driver#690)
* Also add conversions for additional types in ConvertValue ref golang/go@d7c0de9
1 parent 0aa39ff commit 78d399c

5 files changed

+156
-7
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Jian Zhen <zhenjl at gmail.com>
4040
Joshua Prunier <joshua.prunier at gmail.com>
4141
Julien Lefevre <julien.lefevr at gmail.com>
4242
Julien Schmidt <go-sql-driver at julienschmidt.com>
43+
Justin Li <jli at j-li.net>
4344
Justin Nuß <nuss.justin at gmail.com>
4445
Kamil Dziedzic <kamil at klecza.pl>
4546
Kevin Malachowski <kevin at chowski.com>

connection_go18.go

+5
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() {
195195
}
196196
}()
197197
}
198+
199+
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
200+
nv.Value, err = converter{}.ConvertValue(nv.Value)
201+
return
202+
}

connection_go18_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
2+
//
3+
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
4+
//
5+
// This Source Code Form is subject to the terms of the Mozilla Public
6+
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
7+
// You can obtain one at http://mozilla.org/MPL/2.0/.
8+
9+
// +build go1.8
10+
11+
package mysql
12+
13+
import (
14+
"database/sql/driver"
15+
"testing"
16+
)
17+
18+
func TestCheckNamedValue(t *testing.T) {
19+
value := driver.NamedValue{Value: ^uint64(0)}
20+
x := &mysqlConn{}
21+
err := x.CheckNamedValue(&value)
22+
23+
if err != nil {
24+
t.Fatal("uint64 high-bit not convertible", err)
25+
}
26+
27+
if value.Value != "18446744073709551615" {
28+
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
29+
}
30+
}

statement.go

+8
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
157157
return int64(u64), nil
158158
case reflect.Float32, reflect.Float64:
159159
return rv.Float(), nil
160+
case reflect.Bool:
161+
return rv.Bool(), nil
162+
case reflect.Slice:
163+
ek := rv.Type().Elem().Kind()
164+
if ek == reflect.Uint8 {
165+
return rv.Bytes(), nil
166+
}
167+
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
160168
case reflect.String:
161169
return rv.String(), nil
162170
}

statement_test.go

+112-7
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,119 @@
88

99
package mysql
1010

11-
import "testing"
11+
import (
12+
"bytes"
13+
"testing"
14+
)
1215

13-
type customString string
16+
func TestConvertDerivedString(t *testing.T) {
17+
type derived string
1418

15-
func TestConvertValueCustomTypes(t *testing.T) {
16-
var cstr customString = "string"
17-
c := converter{}
18-
if _, err := c.ConvertValue(cstr); err != nil {
19-
t.Errorf("custom string type should be valid")
19+
output, err := converter{}.ConvertValue(derived("value"))
20+
if err != nil {
21+
t.Fatal("Derived string type not convertible", err)
22+
}
23+
24+
if output != "value" {
25+
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
26+
}
27+
}
28+
29+
func TestConvertDerivedByteSlice(t *testing.T) {
30+
type derived []uint8
31+
32+
output, err := converter{}.ConvertValue(derived("value"))
33+
if err != nil {
34+
t.Fatal("Byte slice not convertible", err)
35+
}
36+
37+
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
38+
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
39+
}
40+
}
41+
42+
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
43+
type derived []int
44+
45+
_, err := converter{}.ConvertValue(derived{1})
46+
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
47+
t.Fatal("Unexpected error", err)
48+
}
49+
}
50+
51+
func TestConvertDerivedBool(t *testing.T) {
52+
type derived bool
53+
54+
output, err := converter{}.ConvertValue(derived(true))
55+
if err != nil {
56+
t.Fatal("Derived bool type not convertible", err)
57+
}
58+
59+
if output != true {
60+
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
61+
}
62+
}
63+
64+
func TestConvertPointer(t *testing.T) {
65+
str := "value"
66+
67+
output, err := converter{}.ConvertValue(&str)
68+
if err != nil {
69+
t.Fatal("Pointer type not convertible", err)
70+
}
71+
72+
if output != "value" {
73+
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
74+
}
75+
}
76+
77+
func TestConvertSignedIntegers(t *testing.T) {
78+
values := []interface{}{
79+
int8(-42),
80+
int16(-42),
81+
int32(-42),
82+
int64(-42),
83+
int(-42),
84+
}
85+
86+
for _, value := range values {
87+
output, err := converter{}.ConvertValue(value)
88+
if err != nil {
89+
t.Fatalf("%T type not convertible %s", value, err)
90+
}
91+
92+
if output != int64(-42) {
93+
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
94+
}
95+
}
96+
}
97+
98+
func TestConvertUnsignedIntegers(t *testing.T) {
99+
values := []interface{}{
100+
uint8(42),
101+
uint16(42),
102+
uint32(42),
103+
uint64(42),
104+
uint(42),
105+
}
106+
107+
for _, value := range values {
108+
output, err := converter{}.ConvertValue(value)
109+
if err != nil {
110+
t.Fatalf("%T type not convertible %s", value, err)
111+
}
112+
113+
if output != int64(42) {
114+
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
115+
}
116+
}
117+
118+
output, err := converter{}.ConvertValue(^uint64(0))
119+
if err != nil {
120+
t.Fatal("uint64 high-bit not convertible", err)
121+
}
122+
123+
if output != "18446744073709551615" {
124+
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
20125
}
21126
}

0 commit comments

Comments
 (0)