Skip to content

Commit

Permalink
feat: 优化工厂方法,并且支持非goal app场景下使用本包
Browse files Browse the repository at this point in the history
  • Loading branch information
qbhy committed Jan 31, 2023
1 parent 27c7cfa commit d6a58c4
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 51 deletions.
18 changes: 16 additions & 2 deletions factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,31 @@ package database
import (
"errors"
"github.com/goal-web/contracts"
"github.com/goal-web/database/drivers"
"github.com/goal-web/supports/utils"
)

type Factory struct {
events contracts.EventDispatcher
config contracts.Config
connections map[string]contracts.DBConnection
drivers map[string]contracts.DBConnector
dbConfig Config
}

func NewFactory(config Config, events contracts.EventDispatcher) contracts.DBFactory {
return &Factory{
events: events,
dbConfig: config,
connections: make(map[string]contracts.DBConnection),
drivers: map[string]contracts.DBConnector{
"mysql": drivers.MysqlConnector,
"postgres": drivers.PostgresSqlConnector,
"sqlite": drivers.SqliteConnector,
"clickhouse": drivers.ClickHouseConnector,
},
}
}

func (factory *Factory) Connection(name ...string) contracts.DBConnection {
connection := factory.dbConfig.Default
if len(name) > 0 && name[0] != "" {
Expand All @@ -33,7 +47,7 @@ func (factory *Factory) Extend(name string, driver contracts.DBConnector) {
}

func (factory *Factory) make(name string) contracts.DBConnection {
config := factory.config.Get("database").(Config)
config := factory.dbConfig

if connectionConfig, existsConnection := config.Connections[name]; existsConnection {
driverName := utils.GetStringField(connectionConfig, "driver")
Expand Down
18 changes: 5 additions & 13 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package database

import (
"github.com/goal-web/contracts"
"github.com/goal-web/database/drivers"
"github.com/goal-web/database/migrations"
"github.com/goal-web/database/table"
)

type ServiceProvider struct {
app contracts.Application
migrations contracts.Migrations
}

Expand All @@ -15,25 +16,15 @@ func NewService(migrations contracts.Migrations) contracts.ServiceProvider {
}

func (provider *ServiceProvider) Register(application contracts.Application) {
provider.app = application
application.Instance("migrations", provider.migrations)
application.Singleton("migrations.table", func(config contracts.Config) string {
return config.Get("database").(Config).Migrations
})

application.Singleton("db.factory", func(config contracts.Config) contracts.DBFactory {
events, _ := application.Get("events").(contracts.EventDispatcher)
return &Factory{
events: events,
config: config,
dbConfig: config.Get("database").(Config),
connections: make(map[string]contracts.DBConnection),
drivers: map[string]contracts.DBConnector{
"mysql": drivers.MysqlConnector,
"postgres": drivers.PostgresSqlConnector,
"sqlite": drivers.SqliteConnector,
"clickhouse": drivers.ClickHouseConnector,
},
}
return NewFactory(config.Get("database").(Config), events)
})
application.Singleton("db", func(config contracts.Config, factory contracts.DBFactory) contracts.DBConnection {
return factory.Connection()
Expand All @@ -50,6 +41,7 @@ func (provider *ServiceProvider) Register(application contracts.Application) {
}

func (provider *ServiceProvider) Start() error {
table.SetFactory(provider.app.Get("db.factory").(contracts.DBFactory))
return nil
}

Expand Down
53 changes: 53 additions & 0 deletions table/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package table

import (
"github.com/goal-web/application"
"github.com/goal-web/contracts"
"github.com/goal-web/querybuilder"
)

var factory contracts.DBFactory

func SetFactory(dbFactory contracts.DBFactory) {
factory = dbFactory
}

func getFactory() contracts.DBFactory {
if factory == nil {
factory = application.Get("db.factory").(contracts.DBFactory)
}
return factory
}

func getTable(name string) *Table {
builder := querybuilder.NewQuery(name)
instance := &Table{
QueryBuilder: builder,
primaryKey: "id",
table: name,
}
builder.Bind(instance)
return instance
}

// Query 将使用默认 connection
func Query(name string) *Table {
return getTable(name).SetConnection(factory.Connection())
}

func FromModel(model contracts.Model) *Table {
return WithConnection(model.GetTable(), model.GetConnection()).SetClass(model.GetClass()).SetPrimaryKey(model.GetPrimaryKey())
}

// WithConnection 使用指定链接
func WithConnection(name string, connection interface{}) *Table {
if connection == "" || connection == nil {
return Query(name)
}
return getTable(name).SetConnection(connection)
}

// WithTX 使用TX
func WithTX(name string, tx contracts.DBTx) contracts.QueryBuilder {
return getTable(name).SetExecutor(tx)
}
37 changes: 1 addition & 36 deletions table/table.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package table

import (
"github.com/goal-web/application"
"github.com/goal-web/contracts"
"github.com/goal-web/querybuilder"
"github.com/goal-web/supports/exceptions"
"github.com/goal-web/supports/utils"
)
Expand All @@ -17,45 +15,12 @@ type Table struct {
class contracts.Class
}

func getTable(name string) *Table {
builder := querybuilder.NewQuery(name)
instance := &Table{
QueryBuilder: builder,
primaryKey: "id",
table: name,
}
builder.Bind(instance)
return instance
}

// Query 将使用默认 connection
func Query(name string) *Table {
return getTable(name).SetConnection(application.Get("db").(contracts.DBConnection))
}

func FromModel(model contracts.Model) *Table {
return WithConnection(model.GetTable(), model.GetConnection()).SetClass(model.GetClass()).SetPrimaryKey(model.GetPrimaryKey())
}

// WithConnection 使用指定链接
func WithConnection(name string, connection interface{}) *Table {
if connection == "" || connection == nil {
return Query(name)
}
return getTable(name).SetConnection(connection)
}

// WithTX 使用TX
func WithTX(name string, tx contracts.DBTx) contracts.QueryBuilder {
return getTable(name).SetExecutor(tx)
}

// SetConnection 参数要么是 contracts.DBConnection 要么是 string
func (table *Table) SetConnection(connection interface{}) *Table {
if conn, ok := connection.(contracts.DBConnection); ok {
table.executor = conn
} else {
table.executor = application.Get("db.factory").(contracts.DBFactory).Connection(utils.ConvertToString(connection, ""))
table.executor = getFactory().Connection(utils.ConvertToString(connection, ""))
}
return table
}
Expand Down
31 changes: 31 additions & 0 deletions tests/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,34 @@ func TestMysqlDatabaseService(t *testing.T) {
assert.True(t, table.Query("users").Count() == 0)

}

func TestMysqlDatabaseWithoutApplication(t *testing.T) {
table.SetFactory(database.NewFactory(database.Config{
Default: "mysql",
Connections: map[string]contracts.Fields{
"mysql": {
"driver": "mysql",
"host": "localhost",
"port": "3306",
"database": "goal",
"username": "root",
"password": "123456",
"charset": "utf8mb4",
"collation": "utf8mb4_unicode_ci",
},
},
Migrations: "migrations",
}, nil))

assert.True(t, table.Query("users").Count() == 0)

user := table.Query("users").Create(contracts.Fields{
"name": "testing",
})
assert.NotNil(t, user)
assert.True(t, user.(contracts.Fields)["name"] == "testing")
assert.True(t, table.Query("users").Count() == 1)
table.Query("users").Where("name", "testing").Delete()
assert.True(t, table.Query("users").Count() == 0)

}

0 comments on commit d6a58c4

Please sign in to comment.