diff --git a/types/types.go b/types/types.go index 2ae9dfb4..2c90b4cf 100644 --- a/types/types.go +++ b/types/types.go @@ -57,9 +57,14 @@ func (g *GzippedText) Scan(src interface{}) error { // implements `Unmarshal`, which unmarshals the json within to an interface{} type JSONText json.RawMessage -// MarshalJSON returns j as the JSON encoding of j. -func (j JSONText) MarshalJSON() ([]byte, error) { - return j, nil +var _EMPTY_JSON = JSONText("{}") + +// MarshalJSON returns the *j as the JSON encoding of j. +func (j *JSONText) MarshalJSON() ([]byte, error) { + if len(*j) == 0 { + *j = _EMPTY_JSON + } + return *j, nil } // UnmarshalJSON sets *j to a copy of data @@ -69,7 +74,6 @@ func (j *JSONText) UnmarshalJSON(data []byte) error { } *j = append((*j)[0:0], data...) return nil - } // Value returns j as a value. This does a validating unmarshal into another @@ -86,11 +90,17 @@ func (j JSONText) Value() (driver.Value, error) { // Scan stores the src in *j. No validation is done. func (j *JSONText) Scan(src interface{}) error { var source []byte - switch src.(type) { + switch t := src.(type) { case string: - source = []byte(src.(string)) + source = []byte(t) case []byte: - source = src.([]byte) + if len(t) == 0 { + source = _EMPTY_JSON + } else { + source = t + } + case nil: + *j = _EMPTY_JSON default: return errors.New("Incompatible type for JSONText") } @@ -100,14 +110,43 @@ func (j *JSONText) Scan(src interface{}) error { // Unmarshal unmarshal's the json in j to v, as in json.Unmarshal. func (j *JSONText) Unmarshal(v interface{}) error { + if len(*j) == 0 { + *j = _EMPTY_JSON + } return json.Unmarshal([]byte(*j), v) } -// Pretty printing for JSONText types +// String supports pretty printing for JSONText types. func (j JSONText) String() string { return string(j) } +// NullJSONText represents a JSONText that may be null. +// NullJSONText implements the scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullJSONText struct { + JSONText + Valid bool // Valid is true if JSONText is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullJSONText) Scan(value interface{}) error { + if value == nil { + n.JSONText, n.Valid = _EMPTY_JSON, false + return nil + } + n.Valid = true + return n.JSONText.Scan(value) +} + +// Value implements the driver Valuer interface. +func (n NullJSONText) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.JSONText.Value() +} + // BitBool is an implementation of a bool for the MySQL type BIT(1). // This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT. type BitBool bool diff --git a/types/types_test.go b/types/types_test.go index c682cfd3..cbe200c3 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -39,6 +39,59 @@ func TestJSONText(t *testing.T) { if err == nil { t.Errorf("Was expecting invalid json to fail!") } + + j = JSONText("") + v, err = j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + + j = JSONText(nil) + v, err = j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } +} + +func TestNullJSONText(t *testing.T) { + j := NullJSONText{} + err := j.Scan(`{"foo": 1, "bar": 2}`) + if err != nil { + t.Errorf("Was not expecting an error") + } + v, err := j.Value() + if err != nil { + t.Errorf("Was not expecting an error") + } + err = (&j).Scan(v) + if err != nil { + t.Errorf("Was not expecting an error") + } + m := map[string]interface{}{} + j.Unmarshal(&m) + + if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { + t.Errorf("Expected valid json but got some garbage instead? %#v", m) + } + + j = NullJSONText{} + err = j.Scan(nil) + if err != nil { + t.Errorf("Was not expecting an error") + } + if j.Valid != false { + t.Errorf("Expected valid to be false, but got true") + } } func TestBitBool(t *testing.T) {