diff --git a/livesql/live.go b/livesql/live.go index 4055cde49..3696b2dec 100644 --- a/livesql/live.go +++ b/livesql/live.go @@ -76,7 +76,7 @@ func (t *dbTracker) processBinlog(update *update) { } } -func (t *dbTracker) registerDependency(ctx context.Context, table string, tester sqlgen.Tester, filter sqlgen.Filter) error { +func (t *dbTracker) registerDependency(ctx context.Context, schema *sqlgen.Schema, table string, tester sqlgen.Tester, filter sqlgen.Filter) error { r := &dbResource{ table: table, tester: tester, @@ -86,7 +86,7 @@ func (t *dbTracker) registerDependency(ctx context.Context, table string, tester t.remove(r) }) - proto, err := filterToProto(table, filter) + proto, err := filterToProto(schema, table, filter) if err != nil { return err } @@ -149,7 +149,7 @@ func (ldb *LiveDB) query(ctx context.Context, query *sqlgen.BaseSelectQuery) ([] // Register the dependency before we do the query to not miss any updates // between querying and registering. // Do not fail the query if this step fails. - _ = ldb.tracker.registerDependency(ctx, query.Table.Name, tester, query.Filter) + _ = ldb.tracker.registerDependency(ctx, ldb.Schema, query.Table.Name, tester, query.Filter) // Perform the query. // XXX: This will build the SQL string again... :( @@ -211,7 +211,7 @@ func (ldb *LiveDB) Close() error { } func (ldb *LiveDB) AddDependency(ctx context.Context, proto *thunderpb.SQLFilter) error { - table, filter, err := filterFromProto(proto) + table, filter, err := filterFromProto(ldb.Schema, proto) if err != nil { return err } @@ -221,7 +221,7 @@ func (ldb *LiveDB) AddDependency(ctx context.Context, proto *thunderpb.SQLFilter return err } - if err := ldb.tracker.registerDependency(ctx, table, tester, filter); err != nil { + if err := ldb.tracker.registerDependency(ctx, ldb.Schema, table, tester, filter); err != nil { return err } return nil diff --git a/livesql/marshal.go b/livesql/marshal.go index eb6769e1c..914b6a7df 100644 --- a/livesql/marshal.go +++ b/livesql/marshal.go @@ -2,10 +2,12 @@ package livesql import ( "database/sql/driver" + "errors" "fmt" "reflect" "time" + "github.com/samsarahq/thunder/internal/fields" "github.com/samsarahq/thunder/sqlgen" "github.com/samsarahq/thunder/thunderpb" ) @@ -62,26 +64,87 @@ func fieldToValue(field *thunderpb.Field) (driver.Value, error) { } } -func filterToProto(table string, filter sqlgen.Filter) (*thunderpb.SQLFilter, error) { +// filterToProto takes a sqlgen.Filter, runs Valuer on each filter value, and returns a thunderpb.SQLFilter. +func filterToProto(schema *sqlgen.Schema, tableName string, filter sqlgen.Filter) (*thunderpb.SQLFilter, error) { + table, ok := schema.ByName[tableName] + if !ok { + return nil, fmt.Errorf("unknown table: %s", tableName) + } + + if filter == nil { + return &thunderpb.SQLFilter{Table: tableName}, nil + } + fields := make(map[string]*thunderpb.Field, len(filter)) for col, val := range filter { + column, ok := table.ColumnsByName[col] + if !ok { + return nil, fmt.Errorf("unknown column %s", col) + } + + val, err := column.Descriptor.Valuer(reflect.ValueOf(val)).Value() + if err != nil { + return nil, err + } + field, err := valueToField(val) if err != nil { return nil, err } fields[col] = field } - return &thunderpb.SQLFilter{Table: table, Fields: fields}, nil + return &thunderpb.SQLFilter{Table: tableName, Fields: fields}, nil } -func filterFromProto(proto *thunderpb.SQLFilter) (string, sqlgen.Filter, error) { +// filterFromProto takes a thunderpb.SQLFilter, runs Scanner on each field value, and returns a sqlgen.Filter. +func filterFromProto(schema *sqlgen.Schema, proto *thunderpb.SQLFilter) (string, sqlgen.Filter, error) { + table, ok := schema.ByName[proto.Table] + if !ok { + return "", nil, fmt.Errorf("unknown table: %s", proto.Table) + } + + scanners := table.Scanners.Get().([]interface{}) + defer table.Scanners.Put(scanners) + filter := make(sqlgen.Filter, len(proto.Fields)) for col, field := range proto.Fields { val, err := fieldToValue(field) if err != nil { return "", nil, err } - filter[col] = val + + column, ok := table.ColumnsByName[col] + if !ok { + return "", nil, fmt.Errorf("unknown column %s", col) + } + + if !column.Descriptor.Ptr && val == nil { + return "", nil, errors.New("cannot unmarshal nil into non-pointer type") + } + + scanner := scanners[column.Order].(*fields.Scanner) + + // target is always a pointer. + var target, ptrptr reflect.Value + if column.Descriptor.Ptr { + // We need to hold onto this pointer-pointer in order to make the value addressable. + ptrptr = reflect.New(reflect.PtrTo(column.Descriptor.Type)) + target = ptrptr.Elem() + } else { + target = reflect.New(column.Descriptor.Type) + } + scanner.Target(target) + + if err := scanner.Scan(val); err != nil { + return "", nil, err + } + + if column.Descriptor.Ptr { + filter[col] = target.Interface() + } else { + // Dereference pointer if column type is not a pointer. + filter[col] = target.Elem().Interface() + } } return proto.Table, filter, nil } diff --git a/livesql/marshal_test.go b/livesql/marshal_test.go new file mode 100644 index 000000000..2e47c7fa7 --- /dev/null +++ b/livesql/marshal_test.go @@ -0,0 +1,107 @@ +package livesql + +import ( + "testing" + + "github.com/samsarahq/thunder/internal/testfixtures" + "github.com/samsarahq/thunder/sqlgen" + "github.com/stretchr/testify/assert" +) + +func TestMarshal(t *testing.T) { + type user struct { + Id int64 `sql:",primary"` + Name *string + Uuid testfixtures.CustomType + Mood *testfixtures.CustomType + } + + schema := sqlgen.NewSchema() + schema.MustRegisterType("users", sqlgen.AutoIncrement, user{}) + + one := int64(1) + foo := "foo" + + cases := []struct { + name string + filter sqlgen.Filter + unmarshaled sqlgen.Filter + err bool + }{ + { + name: "nil", + filter: nil, + unmarshaled: sqlgen.Filter{}, + }, + { + name: "empty", + filter: sqlgen.Filter{}, + unmarshaled: sqlgen.Filter{}, + }, + { + name: "uuid", + filter: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")}, + unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")}, + }, + { + name: "uuid from bytes", + filter: sqlgen.Filter{"uuid": []byte("foo")}, + unmarshaled: sqlgen.Filter{"uuid": testfixtures.CustomTypeFromString("foo")}, + }, + { + name: "nil uuid", + filter: sqlgen.Filter{"mood": nil}, + unmarshaled: sqlgen.Filter{"mood": (*testfixtures.CustomType)(nil)}, + }, + { + name: "id", + filter: sqlgen.Filter{"id": int64(1)}, + unmarshaled: sqlgen.Filter{"id": int64(1)}, + }, + { + name: "id int32 to int64", + filter: sqlgen.Filter{"id": int32(1)}, + unmarshaled: sqlgen.Filter{"id": int64(1)}, + }, + { + name: "id int64 ptr to int64", + filter: sqlgen.Filter{"id": &one}, + unmarshaled: sqlgen.Filter{"id": int64(1)}, + }, + { + name: "string to string ptr", + filter: sqlgen.Filter{"name": "foo"}, + unmarshaled: sqlgen.Filter{"name": &foo}}, + { + name: "nil to string ptr", + filter: sqlgen.Filter{"name": nil}, + unmarshaled: sqlgen.Filter{"name": (*string)(nil)}, + }, + { + name: "nil for int64", + filter: sqlgen.Filter{"id": nil}, + err: true, + }, + { + name: "string for int64", + filter: sqlgen.Filter{"id": ""}, + err: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + proto, err := filterToProto(schema, "users", c.filter) + assert.NoError(t, err) + + table, filter, err := filterFromProto(schema, proto) + if c.err { + assert.NotNil(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, "users", table) + assert.Equal(t, c.unmarshaled, filter) + } + }) + } +}