Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/ecodeclub/eorm into dev
Browse files Browse the repository at this point in the history
# Conflicts:
#	.CHANGELOG.md
  • Loading branch information
heroyf committed Mar 25, 2023
2 parents 6d7ce93 + 8996c2e commit 2679e84
Show file tree
Hide file tree
Showing 12 changed files with 40 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
- [eorm: 分库分表: Merger抽象与批量查询实现](https://github.com/ecodeclub/eorm/pull/160)
- [eorm: 增强的 ShardingAlgorithm 设计与实现](https://github.com/ecodeclub/eorm/pull/161)
- [eorm: 分库分表: Merger排序实现](https://github.com/ecodeclub/eorm/pull/166)
- [eorm: 分库分表: Not支持](https://github.com/ecodeclub/eorm/pull/174)
- [eorm: BasicTypeValue重命名](https://github.com/ecodeclub/eorm/pull/177)
- [eorm: 分库分表: hash、shadow_hash算法不符合预期](https://github.com/ecodeclub/eorm/pull/174)

## v0.0.1:
Expand Down
12 changes: 6 additions & 6 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func (q Query) string() string {

func TestQuerier_Get(t *testing.T) {
t.Run("unsafe", func(t *testing.T) {
testQuerierGet(t, valuer.BasicTypeCreator{Creator: valuer.NewUnsafeValue})
testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue})
})

t.Run("reflect", func(t *testing.T) {
testQuerierGet(t, valuer.BasicTypeCreator{Creator: valuer.NewReflectValue})
testQuerierGet(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue})
})
}

func testQuerierGet(t *testing.T, creator valuer.BasicTypeCreator) {
func testQuerierGet(t *testing.T, creator valuer.PrimitiveCreator) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -150,14 +150,14 @@ func testQuerierGet(t *testing.T, creator valuer.BasicTypeCreator) {

func TestQuerierGetMulti(t *testing.T) {
t.Run("unsafe", func(t *testing.T) {
testQuerier_GetMulti(t, valuer.BasicTypeCreator{Creator: valuer.NewUnsafeValue})
testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue})
})
t.Run("reflect", func(t *testing.T) {
testQuerier_GetMulti(t, valuer.BasicTypeCreator{Creator: valuer.NewReflectValue})
testQuerier_GetMulti(t, valuer.PrimitiveCreator{Creator: valuer.NewReflectValue})
})
}

func testQuerier_GetMulti(t *testing.T, creator valuer.BasicTypeCreator) {
func testQuerier_GetMulti(t *testing.T, creator valuer.PrimitiveCreator) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatal(err)
Expand Down
6 changes: 3 additions & 3 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
type core struct {
metaRegistry model.MetaRegistry
dialect dialect.Dialect
valCreator valuer.BasicTypeCreator
valCreator valuer.PrimitiveCreator
ms []Middleware
}

Expand All @@ -49,7 +49,7 @@ func getHandler[T any](ctx context.Context, sess Session, c core, qc *QueryConte
meta, _ = c.metaRegistry.Get(tp)
}

val := c.valCreator.NewBasicTypeValue(tp, meta)
val := c.valCreator.NewPrimitiveValue(tp, meta)
if err = val.SetColumns(rows); err != nil {
return &QueryResult{Err: err}
}
Expand Down Expand Up @@ -85,7 +85,7 @@ func getMultiHandler[T any](ctx context.Context, sess Session, c core, qc *Query
}
for rows.Next() {
tp := new(T)
val := c.valCreator.NewBasicTypeValue(tp, meta)
val := c.valCreator.NewPrimitiveValue(tp, meta)
if err = val.SetColumns(rows); err != nil {
return &QueryResult{Err: err}
}
Expand Down
4 changes: 2 additions & 2 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DBWithMiddlewares(ms ...Middleware) DBOption {

func UseReflection() DBOption {
return func(db *DB) {
db.valCreator = valuer.BasicTypeCreator{Creator: valuer.NewUnsafeValue}
db.valCreator = valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue}
}
}

Expand Down Expand Up @@ -84,7 +84,7 @@ func openDB(driver string, db *sql.DB, opts ...DBOption) (*DB, error) {
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.BasicTypeCreator{
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
},
},
Expand Down
4 changes: 2 additions & 2 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func BenchmarkQuerier_Get(b *testing.B) {
}

b.Run("unsafe", func(b *testing.B) {
orm.valCreator = valuer.BasicTypeCreator{
orm.valCreator = valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
}
for i := 0; i < b.N; i++ {
Expand All @@ -209,7 +209,7 @@ func BenchmarkQuerier_Get(b *testing.B) {
})

b.Run("reflect", func(b *testing.B) {
orm.valCreator = valuer.BasicTypeCreator{
orm.valCreator = valuer.PrimitiveCreator{
Creator: valuer.NewReflectValue,
}
for i := 0; i < b.N; i++ {
Expand Down
2 changes: 1 addition & 1 deletion insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (i *Inserter[T]) Build() (*Query, error) {
i.comma()
}
i.writeString("(")
refVal := i.valCreator.NewBasicTypeValue(val, i.meta)
refVal := i.valCreator.NewPrimitiveValue(val, i.meta)
for j, v := range fields {
fdVal, err := refVal.Field(v.FieldName)
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions internal/valuer/basic.go → internal/valuer/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ import (
"github.com/ecodeclub/eorm/internal/model"
)

// supportBasicTypeValue 支持基本类型 Value
type supportBasicTypeValue struct {
// primitiveValue 支持基本类型 Value
type primitiveValue struct {
Value
val any
valType reflect.Type
}

// Field 返回字段值
func (s supportBasicTypeValue) Field(name string) (reflect.Value, error) {
func (s primitiveValue) Field(name string) (reflect.Value, error) {
return s.Value.Field(name)
}

// SetColumns 设置列值, 支持基本类型,基于 reflect 与 unsafe Value 封装
func (s supportBasicTypeValue) SetColumns(rows *sql.Rows) error {
func (s primitiveValue) SetColumns(rows *sql.Rows) error {
switch s.valType.Elem().Kind() {
case reflect.Struct:
if scanner, ok := s.val.(sql.Scanner); ok {
Expand All @@ -46,15 +46,15 @@ func (s supportBasicTypeValue) SetColumns(rows *sql.Rows) error {
}
}

// BasicTypeCreator 支持基本类型的 Creator, 基于原生的 Creator 扩展
type BasicTypeCreator struct {
// PrimitiveCreator 支持基本类型的 Creator, 基于原生的 Creator 扩展
type PrimitiveCreator struct {
Creator
}

// NewBasicTypeValue 返回一个封装好的,基于支持基本类型实现的 Value
// NewPrimitiveValue 返回一个封装好的,基于支持基本类型实现的 Value
// 输入 val 必须是一个指向结构体实例的指针,而不能是任何其它类型
func (c BasicTypeCreator) NewBasicTypeValue(val any, meta *model.TableMeta) Value {
return supportBasicTypeValue{
func (c PrimitiveCreator) NewPrimitiveValue(val any, meta *model.TableMeta) Value {
return primitiveValue{
val: val,
Value: c.Creator(val, meta),
valType: reflect.TypeOf(val),
Expand Down
22 changes: 11 additions & 11 deletions internal/valuer/basic_test.go → internal/valuer/primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ import (
"github.com/stretchr/testify/assert"
)

func Test_basicTypeValue_Field(t *testing.T) {
testBasicValueField(t, BasicTypeCreator{Creator: NewUnsafeValue})
testBasicValueField(t, BasicTypeCreator{Creator: NewReflectValue})
func Test_primitiveValue_Field(t *testing.T) {
testPrimitiveValueField(t, PrimitiveCreator{Creator: NewUnsafeValue})
testPrimitiveValueField(t, PrimitiveCreator{Creator: NewReflectValue})
}

func testBasicValueField(t *testing.T, creator BasicTypeCreator) {
func testPrimitiveValueField(t *testing.T, creator PrimitiveCreator) {
meta, err := model.NewMetaRegistry().Get(&test.SimpleStruct{})
if err != nil {
t.Fatal(err)
}
t.Run("zero value", func(t *testing.T) {
entity := &test.SimpleStruct{}
testCases := newValueFieldTestCases(entity)
val := creator.NewBasicTypeValue(entity, meta)
val := creator.NewPrimitiveValue(entity, meta)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v, err := val.Field(tc.field)
Expand All @@ -55,7 +55,7 @@ func testBasicValueField(t *testing.T, creator BasicTypeCreator) {
t.Run("normal value", func(t *testing.T) {
entity := test.NewSimpleStruct(1)
testCases := newValueFieldTestCases(entity)
val := creator.NewBasicTypeValue(entity, meta)
val := creator.NewPrimitiveValue(entity, meta)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v, err := val.Field(tc.field)
Expand Down Expand Up @@ -87,7 +87,7 @@ func testBasicValueField(t *testing.T, creator BasicTypeCreator) {
t.Fatal(err)
}

val := creator.NewBasicTypeValue(&User{}, meta)
val := creator.NewPrimitiveValue(&User{}, meta)
for _, tc := range invalidCases {
t.Run(tc.name, func(t *testing.T) {
v, err := val.Field(tc.field)
Expand Down Expand Up @@ -134,7 +134,7 @@ func testBasicValueField(t *testing.T, creator BasicTypeCreator) {
t.Fatal(err)
}

val := creator.NewBasicTypeValue(cUser, meta)
val := creator.NewPrimitiveValue(cUser, meta)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v, err := val.Field(tc.field)
Expand All @@ -148,7 +148,7 @@ func testBasicValueField(t *testing.T, creator BasicTypeCreator) {
})
}

func Test_basicTypeValue_SetColumn(t *testing.T) {
func Test_primitiveValue_SetColumn(t *testing.T) {
testCases := []struct {
name string
cs map[string][]byte
Expand Down Expand Up @@ -281,8 +281,8 @@ func Test_basicTypeValue_SetColumn(t *testing.T) {
t.Fatal(err)
}
defer func() { _ = db.Close() }()
basicCreator := BasicTypeCreator{Creator: tc.valCreator}
val := basicCreator.NewBasicTypeValue(tc.val, meta)
basicCreator := PrimitiveCreator{Creator: tc.valCreator}
val := basicCreator.NewPrimitiveValue(tc.val, meta)
cols := make([]string, 0, len(tc.cs))
colVals := make([]driver.Value, 0, len(tc.cs))
for k, v := range tc.cs {
Expand Down
2 changes: 1 addition & 1 deletion master_slave_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func OpenMasterSlaveDB(driver string, master *sql.DB, opts ...MasterSlaveDBOptio
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
// 可以设为默认,因为原本这里也有默认
valCreator: valuer.BasicTypeCreator{
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
},
},
Expand Down
2 changes: 1 addition & 1 deletion sharding_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func OpenShardingDB(driver string, ds sharding.DataSource, opts ...ShardingDBOpt
core: core{
metaRegistry: model.NewMetaRegistry(),
dialect: dl,
valCreator: valuer.BasicTypeCreator{
valCreator: valuer.PrimitiveCreator{
Creator: valuer.NewUnsafeValue,
},
},
Expand Down
4 changes: 2 additions & 2 deletions sharding_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func (s *ShardingSelector[T]) Get(ctx context.Context) (*T, error) {
return nil, ErrNoRows
}
tp := new(T)
val := s.valCreator.NewBasicTypeValue(tp, s.meta)
val := s.valCreator.NewPrimitiveValue(tp, s.meta)
if err = val.SetColumns(row); err != nil {
return nil, err
}
Expand Down Expand Up @@ -395,7 +395,7 @@ func (s *ShardingSelector[T]) GetMulti(ctx context.Context) ([]*T, error) {
for _, rows := range rowsSlice {
for rows.Next() {
tp := new(T)
val := s.valCreator.NewBasicTypeValue(tp, s.meta)
val := s.valCreator.NewPrimitiveValue(tp, s.meta)
if err = val.SetColumns(rows); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion update.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (u *Updater[T]) Build() (*Query, error) {
return nil, err
}

u.val = u.valCreator.NewBasicTypeValue(u.table, u.meta)
u.val = u.valCreator.NewPrimitiveValue(u.table, u.meta)
u.args = make([]interface{}, 0, len(u.meta.Columns))

u.writeString("UPDATE ")
Expand Down

0 comments on commit 2679e84

Please sign in to comment.