Skip to content

Commit 3d8a029

Browse files
authored
stmt: add json.RawMessage for converter and prepared statement (go-sql-driver#1059)
Following go-sql-driver#1058, in order for the driver.Value to get as a json.RawMessage, the converter should accept it as a valid value, and handle it as bytes in case where interpolation is disabled
1 parent 5a8a207 commit 3d8a029

File tree

5 files changed

+53
-4
lines changed

5 files changed

+53
-4
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Alex Snast <alexsn at fb.com>
1717
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
1818
Andrew Reid <andrew.reid at tixtrack.com>
1919
Arne Hormann <arnehormann at gmail.com>
20+
Ariel Mashraki <ariel at mashraki.co.il>
2021
Asta Xie <xiemengjun at gmail.com>
2122
Bulat Gaifullin <gaifullinbf at gmail.com>
2223
Carlos Nieto <jose.carlos at menteslibres.net>

driver_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"crypto/tls"
1515
"database/sql"
1616
"database/sql/driver"
17+
"encoding/json"
1718
"fmt"
1819
"io"
1920
"io/ioutil"
@@ -559,6 +560,29 @@ func TestRawBytes(t *testing.T) {
559560
})
560561
}
561562

563+
func TestRawMessage(t *testing.T) {
564+
runTests(t, dsn, func(dbt *DBTest) {
565+
v1 := json.RawMessage("{}")
566+
v2 := json.RawMessage("[]")
567+
rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
568+
defer rows.Close()
569+
if rows.Next() {
570+
var o1, o2 json.RawMessage
571+
if err := rows.Scan(&o1, &o2); err != nil {
572+
dbt.Errorf("Got error: %v", err)
573+
}
574+
if !bytes.Equal(v1, o1) {
575+
dbt.Errorf("expected %v, got %v", v1, o1)
576+
}
577+
if !bytes.Equal(v2, o2) {
578+
dbt.Errorf("expected %v, got %v", v2, o2)
579+
}
580+
} else {
581+
dbt.Errorf("no data")
582+
}
583+
})
584+
}
585+
562586
type testValuer struct {
563587
value string
564588
}

packets.go

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"crypto/tls"
1414
"database/sql/driver"
1515
"encoding/binary"
16+
"encoding/json"
1617
"errors"
1718
"fmt"
1819
"io"
@@ -1003,6 +1004,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10031004
continue
10041005
}
10051006

1007+
if v, ok := arg.(json.RawMessage); ok {
1008+
arg = []byte(v)
1009+
}
10061010
// cache types and values
10071011
switch v := arg.(type) {
10081012
case int64:

statement.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"database/sql/driver"
13+
"encoding/json"
1314
"fmt"
1415
"io"
1516
"reflect"
@@ -129,6 +130,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
129130
return rows, err
130131
}
131132

133+
var jsonType = reflect.TypeOf(json.RawMessage{})
134+
132135
type converter struct{}
133136

134137
// ConvertValue mirrors the reference/default converter in database/sql/driver
@@ -151,7 +154,6 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
151154
}
152155
return sv, nil
153156
}
154-
155157
rv := reflect.ValueOf(v)
156158
switch rv.Kind() {
157159
case reflect.Ptr:
@@ -170,11 +172,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
170172
case reflect.Bool:
171173
return rv.Bool(), nil
172174
case reflect.Slice:
173-
ek := rv.Type().Elem().Kind()
174-
if ek == reflect.Uint8 {
175+
switch t := rv.Type(); {
176+
case t == jsonType:
177+
return v, nil
178+
case t.Elem().Kind() == reflect.Uint8:
175179
return rv.Bytes(), nil
180+
default:
181+
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
176182
}
177-
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
178183
case reflect.String:
179184
return rv.String(), nil
180185
}

statement_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package mysql
1010

1111
import (
1212
"bytes"
13+
"encoding/json"
1314
"testing"
1415
)
1516

@@ -124,3 +125,17 @@ func TestConvertUnsignedIntegers(t *testing.T) {
124125
t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
125126
}
126127
}
128+
129+
func TestConvertJSON(t *testing.T) {
130+
raw := json.RawMessage("{}")
131+
132+
out, err := converter{}.ConvertValue(raw)
133+
134+
if err != nil {
135+
t.Fatal("json.RawMessage was failed in convert", err)
136+
}
137+
138+
if _, ok := out.(json.RawMessage); !ok {
139+
t.Fatalf("json.RawMessage converted, got %#v %T", out, out)
140+
}
141+
}

0 commit comments

Comments
 (0)