diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..12fd6cc --- /dev/null +++ b/builder.go @@ -0,0 +1,288 @@ +package ormx + +import ( + "context" + "fmt" + "reflect" + "strings" + + sb "github.com/huandu/go-sqlbuilder" +) + +type Builder interface { + Build() (string, []any) +} + +// Build is same with builder.Build, but it will try to inject namespace(which defined in context) filter into where condition in sql +func Build(ctx context.Context, b Builder) (string, []any) { + switch x := b.(type) { + case *sb.UpdateBuilder: + appendNamespaceFilter(ctx, &x.Cond) + case *sb.SelectBuilder: + appendNamespaceFilter(ctx, &x.Cond) + case *sb.DeleteBuilder: + appendNamespaceFilter(ctx, &x.Cond) + } + return b.Build() +} + +func appendNamespaceFilter(ctx context.Context, cond *sb.Cond) *sb.Cond { + v := ctx.Value(namespaceCtxKey{}) + if v == nil { + // no namespace in context + return cond + } + + s, ok := v.(string) + if !ok { + // incorrect namespace value in context, ignore it + return cond + } + + if s == "" || s == "-" || s == "" { + // empty namespace value in context, ignore it + return cond + } + + if shouldIgnoreNamespace(ctx) { + // user-defined force ignore namespace in context + return cond + } + + // append namespace where condition into Cond + cond.E(namespaceColumnName, s) + return cond +} + +// WithNamespace add namespace info into context +func WithNamespace(ctx context.Context, namespace string) context.Context { + return context.WithValue(ctx, namespaceCtxKey{}, namespace) +} + +// IgnnoreNamespace force ormx ignore the namespace info in context, so that Build will not inject namespace filter in sql +func IgnoreNamespace(ctx context.Context) context.Context { + return context.WithValue(ctx, ignoreNamespaceCtxKey{}, true) +} + +func shouldIgnoreNamespace(ctx context.Context) bool { + v := ctx.Value(ignoreNamespaceCtxKey{}) + if v == nil { + return false + } + switch x := v.(type) { + case bool: + return x + case string: + x = strings.ToLower(x) + return x == "1" || x == "true" + case int, int64, uint, uint64, float64, float32: + return x != 0 + } + return false +} + +type namespaceCtxKey struct{} +type ignoreNamespaceCtxKey struct{} + +// WhereFromStruct generate where exprs from data(type of struct), the returned value can be used by builder.Where method +func WhereFromStruct(data any, dst []string) []string { + if data == nil { + return []string{} + } + v := dereferencedValue(reflect.ValueOf(data)) + t := dereferencedType(reflect.TypeOf(data)) + if !v.IsValid() || v.IsZero() { + return []string{} + } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + if field.IsNil() { + continue + } + name, _ := colNameFromTag(fieldType) + if name == "" { + continue + } + fieldValue := dereferencedValue(field).Interface() + dst = appendWhereExpr(dst, name, fieldValue, fieldType.Tag.Get("op")) + } + return dst +} + +// WhereFromStruct generate where exprs from []KV, the returned value can be used by builder.Where method +func WhereFromKVs(filter KVs, dst []string) []string { + if filter == nil { + return []string{} + } + for _, kv := range filter { + colName := kv.Key + fieldValue := kv.Value + dst = appendWhereExpr(dst, colName, fieldValue, kv.Extra) + } + return dst +} + +func WhereFromIDs(idList []int64, dst []string) []string { + c := sb.NewCond() + t := reflect.TypeOf(idList) + if t.Kind() == reflect.Slice { + dst = append(dst, c.In(primaryKey, Any2Slice(idList)...)) + } + return dst +} + +func WhereFromID(id int64, dst []string) []string { + c := sb.NewCond() + t := reflect.TypeOf(id) + if t.Kind() == reflect.Slice { + dst = append(dst, c.E(primaryKey, id)) + } + return dst +} + +func WhereFrom(filter any, dst []string) []string { + c := sb.NewCond() + if kvs, ok := filter.(KVs); ok { + return WhereFromKVs(kvs, dst) + } + t := dereferencedType(reflect.TypeOf(filter)) + if kind := t.Kind(); kind == reflect.Struct { + return WhereFromStruct(filter, dst) + } else if kind == reflect.Slice { + dst = append(dst, c.In(primaryKey, Any2Slice(filter)...)) + } else { + dst = append(dst, c.E(primaryKey, filter)) + } + return dst +} + +func appendWhereExpr(dst []string, column string, value any, op string) []string { + c := &sb.Cond{ + Args: &sb.Args{}, + } + switch op { + case "": + if dereferencedType(reflect.TypeOf(value)).Kind() == reflect.Slice { + dst = append(dst, c.In(column, Any2Slice(value)...)) + } else { + dst = append(dst, c.E(column, value)) + } + case "e": + dst = append(dst, c.E(column, value)) + case "ne": + dst = append(dst, c.NE(column, value)) + case "gt": + dst = append(dst, c.GreaterThan(column, value)) + case "gte": + dst = append(dst, c.GreaterEqualThan(column, value)) + case "lt": + dst = append(dst, c.LessThan(column, value)) + case "lte": + dst = append(dst, c.LessEqualThan(column, value)) + case "in": + if values := Any2Slice(value); len(values) > 0 { + dst = append(dst, c.In(column, Any2Slice(value)...)) + } else { + dst = append(dst, c.IsNull(column)) + } + case "notin": + dst = append(dst, c.NotIn(column, Any2Slice(value)...)) + case "like": + dst = append(dst, c.Like(column, value)) + case "notlike": + dst = append(dst, c.NotLike(column, value)) + } + return dst +} + +// TableName auto recoganize the table name from data, it will auto prepend the tableNamePrefix which can be set by SetTableNamePrefix to the result. +// - having Table() method, it will call d.Table() to get the table name +// - type of struct, it will use the struct name, and snake case it +// - type of string, return the name. +// - type of other, return fmt.Sprintf("%s", d) +func TableName(d interface{}) string { + if d == nil { + return "" + } +S: + t, ok := d.(interface { + Table() string + }) + + if ok { + return t.Table() + } + + var ( + vt = dereferencedElemType(reflect.TypeOf(d)) + name string + ) + + switch vt.Kind() { + case reflect.String: + return d.(string) + case reflect.Struct: + structName := vt.Name() + name = sb.SnakeCaseMapper(structName) + case reflect.Slice: + d = reflect.New(vt.Elem()).Interface() + goto S + default: + name = fmt.Sprintf("%s", d) + } + + if strings.HasPrefix(name, tableNamePrefix) { + return name + } + + return tableNamePrefix + name +} + +// ColNamesWithTagOpt will column names from structure data, the type of d must be a struct, otherwise will return []string{}. +// +// ColNamesWithTagOpt will try to filter the filter the struct field which having specified in StructField.Tag if is not empty +func ColNamesWithTagOpt(d interface{}, tag string) []string { + vt := reflect.TypeOf(d) + if vt.Kind() == reflect.Ptr { + vt = vt.Elem() + } + if vt.Kind() != reflect.Struct { + return []string{} + } + table := TableName(d) + var cols []string + for i := 0; i < vt.NumField(); i++ { + field := vt.Field(i) + name, after := colNameFromTag(field) + if name == "" { + continue + } + if tag != "" { + if opts := ParseOptionStr(after); opts != nil { + if _, ok := opts[tag]; !ok { + continue + } + } + } + cols = append(cols, table+"."+name) + } + return cols +} + +func colNameFromTag(field reflect.StructField) (string, string) { + if !field.IsExported() { + return "", "" + } + switch field.Type.Kind() { + case reflect.Func, reflect.Chan: + return "", "" + } + name, after, _ := strings.Cut(field.Tag.Get(structTagName), ",") + if name == "-" { + return "", "" + } else if name == "" { + return field.Name, after + } + return name, after +} diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..8887edb --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,154 @@ +package cache + +import ( + "fmt" + "reflect" + "strings" + "sync" + "time" + + lru "github.com/hashicorp/golang-lru" +) + +var ( + lock sync.RWMutex + // Assuming each item is 128 bytes, we allocate 25% of our available memory to the cache. + defaultSize = GetMemoryLimit() / 4 / 128 + cache *lru.TwoQueueCache + finalizers []Finalizer +) + +type Finalizer func(any, any) + +type Value struct { + Data interface{} + Expire time.Time +} + +// Init the cache size +func Init(fs ...Finalizer) (err error) { + size := defaultSize + if size == 0 { + size = 1024 * 1024 * 64 // 64M + } + cache, err = lru.New2Q(int(size)) + if err != nil { + return err + } + finalizers = fs + go tick() + return +} + +func Contains(key interface{}) bool { + return cache.Contains(key) +} + +func Expire() { + lock.Lock() + defer lock.Unlock() + now := time.Now() + for _, key := range cache.Keys() { + if v, ok := cache.Get(key); ok && v.(Value).Expire.Before(now) { + cache.Remove(key) + for _, f := range finalizers { + f(key, v.(Value).Data) + } + } + } +} + +func getByKey(key interface{}) (interface{}, bool) { + lock.Lock() + defer lock.Unlock() + value, ok := cache.Get(key) + if !ok { + return nil, false + } + ins := value.(Value) + if ins.Expire.Before(time.Now()) { + cache.Remove(key) + for _, f := range finalizers { + f(key, ins.Data) + } + return nil, false + } + return ins.Data, true +} + +func Get(keys ...interface{}) (interface{}, bool) { + key := joinSlice(keys, "/") + return getByKey(key) +} + +func Set(ttl time.Duration, keyAndValue ...any) { + if len(keyAndValue) <= 2 { + return + } + + keys := keyAndValue[:len(keyAndValue)-1] + value := keyAndValue[len(keyAndValue)-1] + key := joinSlice(keys, "/") + lock.Lock() + defer lock.Unlock() + cache.Add(key, Value{ + Data: value, + Expire: time.Now().Add(ttl), + }) +} + +func Remove(keys ...any) { + key := joinSlice(keys, "/") + lock.Lock() + defer lock.RUnlock() + value, ok := cache.Get(key) + if !ok { + return + } + cache.Remove(key) + for _, f := range finalizers { + f(key, value.(Value).Data) + } +} + +func Try(dest any, fallback func() error, ttl time.Duration, keys ...any) error { + key := joinSlice(keys, "/") + value, ok := getByKey(key) + if !ok { + if err := fallback(); err != nil { + return err + } + cache.Add(key, Value{ + Data: reflect.ValueOf(dest).Elem().Interface(), + Expire: time.Now().Add(ttl), + }) + } else if dest == nil { + return fmt.Errorf("dest is nil") + } else { + reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(value)) + } + return nil +} + +func Len() int { + return cache.Len() +} + +func tick() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for range ticker.C { + Expire() + } +} + +func joinSlice[T int | string | int64 | int32 | int16 | int8 | uint32 | uint64 | uint16 | uint8 | float64 | float32 | any](data []T, split string) string { + builder := &strings.Builder{} + for i, item := range data { + builder.WriteString(fmt.Sprintf("%v", item)) + if i < len(data)-1 { + builder.WriteString(split) + } + } + return builder.String() +} diff --git a/cache/mem.go b/cache/mem.go new file mode 100644 index 0000000..d3f76b5 --- /dev/null +++ b/cache/mem.go @@ -0,0 +1,87 @@ +package cache + +import ( + "fmt" + "os" + "path" + "strconv" + "strings" +) + +// GetMemoryLimit returns cgroup memory limit +func GetMemoryLimit() int64 { + // Try determining the amount of memory inside docker container. + // See https://stackoverflow.com/questions/42187085/check-mem-limit-within-a-docker-container + // + // Read memory limit according to https://unix.stackexchange.com/questions/242718/how-to-find-out-how-much-memory-lxc-container-is-allowed-to-consume + // This should properly determine the limit inside lxc container. + // See https://github.com/VictoriaMetrics/VictoriaMetrics/issues/84 + n, err := getMemStat("memory.limit_in_bytes") + if err == nil { + return n + } + n, err = getMemStatV2("memory.max") + if err != nil { + return 0 + } + return n +} + +func getMemStatV2(statName string) (int64, error) { + // See https: //www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html#memory-interface-files + return getStatGeneric(statName, "/sys/fs/cgroup", "/proc/self/cgroup", "") +} + +func getMemStat(statName string) (int64, error) { + return getStatGeneric(statName, "/sys/fs/cgroup/memory", "/proc/self/cgroup", "memory") +} + +func getStatGeneric(statName, sysfsPrefix, cgroupPath, cgroupGrepLine string) (int64, error) { + data, err := getFileContents(statName, sysfsPrefix, cgroupPath, cgroupGrepLine) + if err != nil { + return 0, err + } + data = strings.TrimSpace(data) + n, err := strconv.ParseInt(data, 10, 64) + if err != nil { + return 0, fmt.Errorf("cannot parse %q: %w", cgroupPath, err) + } + return n, nil +} + +func getFileContents(statName, sysfsPrefix, cgroupPath, cgroupGrepLine string) (string, error) { + filepath := path.Join(sysfsPrefix, statName) + data, err := os.ReadFile(filepath) + if err == nil { + return string(data), nil + } + cgroupData, err := os.ReadFile(cgroupPath) + if err != nil { + return "", err + } + subPath, err := grepFirstMatch(string(cgroupData), cgroupGrepLine, 2, ":") + if err != nil { + return "", fmt.Errorf("cannot find cgroup path for %q in %q: %w", cgroupGrepLine, cgroupPath, err) + } + filepath = path.Join(sysfsPrefix, subPath, statName) + data, err = os.ReadFile(filepath) + if err != nil { + return "", err + } + return string(data), nil +} + +// grepFirstMatch searches match line at data and returns item from it by index with given delimiter. +func grepFirstMatch(data string, match string, index int, delimiter string) (string, error) { + lines := strings.Split(string(data), "\n") + for _, s := range lines { + if !strings.Contains(s, match) { + continue + } + parts := strings.Split(s, delimiter) + if index < len(parts) { + return strings.TrimSpace(parts[index]), nil + } + } + return "", fmt.Errorf("cannot find %q in %q", match, data) +} diff --git a/delete.go b/delete.go new file mode 100644 index 0000000..9c8b83d --- /dev/null +++ b/delete.go @@ -0,0 +1,52 @@ +package ormx + +import ( + "context" + + sb "github.com/huandu/go-sqlbuilder" + "github.com/jmoiron/sqlx" +) + +// DeleteWhere delete rows that match the filter from the given table +func DeleteWhere(ctx context.Context, table string, filter KVs) error { + return DeleteWhereTx(ctx, nil, table, filter) +} + +// DeleteWhereTx delete rows that match the filter in transaction from the given table +func DeleteWhereTx(ctx context.Context, tx *sqlx.Tx, table string, filter KVs) error { + builder := sb.NewDeleteBuilder().DeleteFrom(table) + builder = builder.Where(WhereFromKVs(filter, nil)...) + var ( + sql, args = Build(ctx, builder) + err error + ) + if tx == nil { + _, err = Exec(ctx, sql, args...) + } else { + _, err = ExecTx(ctx, tx, sql, args...) + } + return err +} + +// DeleteWhere delete rows by id from the given table +func DeleteByID(ctx context.Context, table string, id ...any) error { + return DeleteByIDTx(ctx, nil, table, id...) +} + +// DeleteWhere delete rows by id in transaction from the table +func DeleteByIDTx(ctx context.Context, tx *sqlx.Tx, table string, id ...any) error { + builder := sb.NewDeleteBuilder().DeleteFrom(table) + builder = builder.Where(WhereFrom(id, nil)...) + var ( + err error + sql string + args []any + ) + sql, args = Build(ctx, builder) + if tx == nil { + _, err = Exec(ctx, sql, args...) + } else { + _, err = ExecTx(ctx, tx, sql, args...) + } + return err +} diff --git a/exec.go b/exec.go new file mode 100644 index 0000000..48c02d3 --- /dev/null +++ b/exec.go @@ -0,0 +1,115 @@ +package ormx + +import ( + "context" + "database/sql/driver" + + "github.com/jmoiron/sqlx" +) + +var ( + p DBProvider +) + +// DBProvider +type DBProvider func(isMaster bool) *sqlx.DB + +// RunTxContext execute a transiction +func RunTxContext(ctx context.Context, f func(ctx context.Context, tx *sqlx.Tx) error) error { + db := Master() + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return err + } + + if err := f(ctx, tx); err != nil { + if err := tx.Rollback(); err != nil { + return err + } + return err + } + + return tx.Commit() +} + +// Exec execute a sql on master DB +func Exec(ctx context.Context, sql string, args ...interface{}) (driver.Result, error) { + log.Printf(ctx, InfoLevel, "Executing sql: %s, args: [%v]", sql, args) + emitMetric(ctx, sql) + return Master().ExecContext(ctx, sql, args...) +} + +// Exec execute a sql in transaction +func ExecTx(ctx context.Context, tx *sqlx.Tx, sql string, args ...interface{}) (driver.Result, error) { + log.Printf(ctx, InfoLevel, "Executing sql: %s, args: [%v]", sql, args) + emitMetric(ctx, sql) + return tx.ExecContext(ctx, sql, args...) +} + +// Select will query data into dest with raw sql and args. +// +// it will auto query from master if the context having FromMaster +func Select(ctx context.Context, dest interface{}, sql string, args ...interface{}) error { + var ( + db *sqlx.DB + ) + if isFromMaster(ctx) { + db = Master() + log.Printf(ctx, InfoLevel, "Selecting on master: %s, args: [%v]", sql, args) + } else { + db = Slave() + log.Printf(ctx, DebugLevel, "Selecting on slave: %s, args: [%v]", sql, args) + } + emitMetric(ctx, sql) + return db.SelectContext(ctx, dest, sql, args...) +} + +// Select will query data into dest with raw sql and args. +// +// it will auto query from master if the context having FromMaster +func SelectTx(ctx context.Context, tx *sqlx.Tx, dest interface{}, sql string, args ...interface{}) error { + log.Printf(ctx, InfoLevel, "Selecting sql: %s, args: [%v]", sql, args) + emitMetric(ctx, sql) + return tx.SelectContext(ctx, dest, sql, args...) +} + +// Get will get one data into dest with raw sql and args. +// +// it will auto query from master if the context having FromMaster +func Get(ctx context.Context, dest interface{}, sql string, args ...interface{}) error { + var ( + db *sqlx.DB + ) + if isFromMaster(ctx) { + db = Master() + log.Printf(ctx, InfoLevel, "Getting by sql: %s, args: [%v]", sql, args) + } else { + db = Slave() + log.Printf(ctx, DebugLevel, "Getting by sql: %s, args: [%v]", sql, args) + } + emitMetric(ctx, sql) + return db.GetContext(ctx, dest, sql, args...) +} + +// Get will get one data from tx by using raw sql and args. +func GetTx(ctx context.Context, tx *sqlx.Tx, dest interface{}, sql string, args ...interface{}) error { + log.Printf(ctx, InfoLevel, "Getting by sql: %s, args: [%v]", sql, args) + emitMetric(ctx, sql) + return tx.GetContext(ctx, dest, sql, args...) +} + +// Master return master *sqlx.DB which returned by DBProvider, panic if DBProvider is not Initilized +func Master() *sqlx.DB { + if p == nil { + panic("db getter is nil, call ormx.Init to initilaze the DBGetter") + } + return p(true) +} + +// Master return slave *sqlx.DB which returned by DBProvider, panic if DBProvider is not Initilized +func Slave() *sqlx.DB { + if p == nil { + panic("db getter is nil, call ormx.Init to initilaze the DBGetter") + } + return p(false) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..eb4e0c3 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/cloudfly/ormx + +go 1.22.2 + +require github.com/hashicorp/golang-lru v1.0.2 + +require ( + github.com/huandu/go-sqlbuilder v1.28.1 // indirect + github.com/huandu/xstrings v1.4.0 // indirect + github.com/jmoiron/sqlx v1.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a32b6fd --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= +github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/huandu/go-assert v1.1.6/go.mod h1:JuIfbmYG9ykwvuxoJ3V8TB5QP+3+ajIA54Y44TmkMxs= +github.com/huandu/go-sqlbuilder v1.28.1 h1:unk88CvOCvnUrOebhB0Q8KXtTkKENeYnT/L2prchnik= +github.com/huandu/go-sqlbuilder v1.28.1/go.mod h1:mS0GAtrtW+XL6nM2/gXHRJax2RwSW1TraavWDFAc1JA= +github.com/huandu/xstrings v1.4.0 h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU= +github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/insert.go b/insert.go new file mode 100644 index 0000000..8331fd8 --- /dev/null +++ b/insert.go @@ -0,0 +1,180 @@ +package ormx + +import ( + "context" + "database/sql/driver" + "fmt" + "reflect" + + sb "github.com/huandu/go-sqlbuilder" + "github.com/jmoiron/sqlx" +) + +// InsertIgnore insert new data into database and ingore the rows on duplicate keys +func InsertIgnore(ctx context.Context, table string, data ...any) error { + return InsertIgnoreTx(ctx, nil, table, data...) +} + +// InsertIgnoreTx insert new data into database and ingore the rows on duplicate keys using transaction +func InsertIgnoreTx(ctx context.Context, tx *sqlx.Tx, table string, data ...any) error { + if len(data) == 0 { + return nil + } + var ( + err error + ) + if table == "" { + table = TableName(data) + } + ib, err := NewInsertBuilderFromStruct(table, data[0]) + if err != nil { + return fmt.Errorf("create insert builder from structure error: %w", err) + } + ib = ib.InsertIgnoreInto(table) + sql, args := Build(ctx, ib) + + if tx == nil { + _, err = Exec(ctx, sql, args...) + } else { + _, err = ExecTx(ctx, tx, sql, args...) + } + if err != nil { + return fmt.Errorf("exec error: %w", err) + } + return nil +} + +// InsertManyTx insert rows in transaction, the all data type should be same structure. +func InsertMany(ctx context.Context, table string, data ...any) error { + return InsertManyTx(ctx, nil, table, data...) +} + +// InsertManyTx insert rows in transaction, the all data type should be same structure. +func InsertManyTx(ctx context.Context, tx *sqlx.Tx, table string, data ...any) error { + if len(data) == 0 { + return nil + } + var ( + err error + ) + if table == "" { + table = TableName(data) + } + ib, err := NewInsertBuilderFromStruct(table, data[0]) + if err != nil { + return fmt.Errorf("create insert builder from structure error: %w", err) + } + sql, args := Build(ctx, ib) + + if tx == nil { + _, err = Exec(ctx, sql, args...) + } else { + _, err = ExecTx(ctx, tx, sql, args...) + } + if err != nil { + return fmt.Errorf("exec error: %w", err) + } + return nil +} + +// InsertOneTx insert rows into table, the data type should be structure. +func InsertOne(ctx context.Context, table string, data any) (int64, error) { + return InsertOneTx(ctx, nil, table, data) +} + +// InsertOneTx insert rows in transaction, the data type should be structure. +func InsertOneTx(ctx context.Context, tx *sqlx.Tx, table string, data any) (int64, error) { + if data == nil { + return 0, nil + } + var ( + err error + id int64 + r driver.Result + ) + if table == "" { + table = TableName(data) + } + ib, err := NewInsertBuilderFromStruct(table, data) + if err != nil { + return 0, fmt.Errorf("create insert builder from structure error: %w", err) + } + sql, args := Build(ctx, ib) + + if tx == nil { + r, err = Exec(ctx, sql, args...) + } else { + r, err = ExecTx(ctx, tx, sql, args...) + } + if err != nil { + return 0, fmt.Errorf("exec error: %w", err) + } + id, err = r.LastInsertId() + if err != nil { + return 0, fmt.Errorf("get last insert id: %w", err) + } + return id, nil +} + +// NewInsertBuilderFromStruct 使用一组数据 data 数据构建 INSERT SQL Builder +// data 的数据类型必须一致,并且属于 struct 类型。 +// data 数据类型中,定义了 fieldtag:"insert" 信息的字段才会被插入。 +func NewInsertBuilderFromStruct(table string, data ...any) (*sb.InsertBuilder, error) { + if len(data) <= 0 { + return nil, fmt.Errorf("no data to insert") + } + if table == "" { + table = TableName(data[0]) + } + + // 使用第一个数据的类型,获取列名信息。 + var ( + ib = sb.NewInsertBuilder().InsertInto(table) + t = dereferencedType(reflect.TypeOf(data[0])) + cols []string + fieldTags = make([]string, t.NumField()) + ) + for i := 0; i < t.NumField(); i++ { + fieldType := t.Field(i) + name, after := colNameFromTag(fieldType) + if name == "" { + continue + } + opts := ParseOptionStr(after) + if _, ok := opts["insert"]; !ok { + continue + } + cols = append(cols, name) + + if t := opts["type"]; t != "" { + fieldTags[i] = t + } else { + fieldTags[i] = "-" + } + } + + if len(cols) == 0 { + return nil, fmt.Errorf(`no insert field defined in '%s' type, defined db:",insert" for insert field`, t.Name()) + } + + ib.Cols(cols...) + + for _, item := range data { + var ( + v = dereferencedValue(reflect.ValueOf(item)) + vals []any + ) + if !v.IsValid() || v.IsZero() { + continue + } + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if fieldTags[i] != "" { + vals = append(vals, convertValueByDBType(dereferencedValue(field).Interface(), fieldTags[i])) + } + } + ib.Values(vals...) + } + + return ib, nil +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..c6d94e8 --- /dev/null +++ b/log.go @@ -0,0 +1,42 @@ +package ormx + +import "context" + +type Level int8 + +const ( + // DebugLevel defines debug log level. + DebugLevel Level = iota + // InfoLevel defines info log level. + InfoLevel + // WarnLevel defines warn log level. + WarnLevel + // ErrorLevel defines error log level. + ErrorLevel + // FatalLevel defines fatal log level. + FatalLevel + // PanicLevel defines panic log level. + PanicLevel + // NoLevel defines an absent log level. + NoLevel + // Disabled disables the logger. + Disabled + + // TraceLevel defines trace log level. + TraceLevel Level = -1 + // Values less than TraceLevel are handled as numbers. +) + +type Logger interface { + Printf(context.Context, Level, string, ...any) +} + +type nopLogger struct{} + +func (l nopLogger) Printf(ctx context.Context, level Level, format string, args ...any) {} + +var log Logger = nopLogger{} + +func SetLogger(l Logger) { + log = l +} diff --git a/metric.go b/metric.go new file mode 100644 index 0000000..0f2e27a --- /dev/null +++ b/metric.go @@ -0,0 +1,50 @@ +package ormx + +import ( + "context" + "strings" +) + +type MetricHandler interface { + Emit(context.Context, string, bool) +} + +var metricHandler MetricHandler + +func emitMetric(ctx context.Context, sql string) { + if metricHandler == nil { + return + } + + var ( + table string + write bool + ) + + if i := strings.Index(sql, " FROM "); i > 0 { + write = false + sql = strings.TrimSpace(sql[i+6:]) + } else if j := strings.Index(sql, " INTO "); j > 0 { + write = true + sql = strings.TrimSpace(sql[j+6:]) + } else if k := strings.Index(sql, "UPDATE "); k >= 0 { + write = true + sql = strings.TrimSpace(sql[k+7:]) + } else { + return + } + + tableName, subSQL, ok := strings.Cut(sql, " ") + if ok { + table = tableName + } else { + emitMetric(ctx, subSQL) + return + } + + metricHandler.Emit(ctx, table, write) +} + +func SetMetricHandler(h MetricHandler) { + metricHandler = h +} diff --git a/orm.go b/orm.go new file mode 100644 index 0000000..3cbb1f3 --- /dev/null +++ b/orm.go @@ -0,0 +1,64 @@ +package ormx + +import ( + "context" + "fmt" +) + +var ( + tableNamePrefix = "" + structTagName = "db" + namespaceColumnName = "namespace" + primaryKey = "id" +) + +// Init the ormx, setting the sqlx.DB getter and common table name prefix +func Init(provider DBProvider, tablePrefix string) { + p = provider + tableNamePrefix = tablePrefix +} + +// SetStructTagName set the tag name in Go Struct Tag, in which specify the ormx options, default is 'db' +func SetStructTagName(name string) { + structTagName = name +} + +// SetPrimaryKey set the primary column name, default is 'id' +func SetPrimaryKey(name string) { + if name != "" { + primaryKey = name + } +} + +// SetNamespaceColumnName set the common namespace colunm name, default is 'namespace'; +// +// ormx will auto inject namespace where condition into sql.// Set to empty string disable this feature +func SetNamespaceColumnName(name string) { + namespaceColumnName = name +} + +type masterCtxKey struct{} + +func isFromMaster(ctx context.Context) bool { + return fmt.Sprintf("%v", ctx.Value(masterCtxKey{})) == "true" +} + +// FromMaster force ormx execute sql on master instance when called by this context +func FromMaster(ctx context.Context) context.Context { + return context.WithValue(ctx, masterCtxKey{}, "true") +} + +// FromMaster force ormx execute sql on slave instance when called by this context +func FromSlave(ctx context.Context) context.Context { + return context.WithValue(ctx, masterCtxKey{}, "false") +} + +func convertValueByDBType(v any, tag string) any { + switch tag { + case "timestamp": + if t, ok := Any2Time(v); ok { + return t + } + } + return v +} diff --git a/select.go b/select.go new file mode 100644 index 0000000..84f0426 --- /dev/null +++ b/select.go @@ -0,0 +1,149 @@ +package ormx + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/cloudfly/ormx/cache" + sb "github.com/huandu/go-sqlbuilder" +) + +func GetByID(ctx context.Context, dst interface{}, table string, id int64) error { + if table == "" { + table = TableName(dst) + } + + if !isFromMaster(ctx) { + // Not reading data from the primary database indicates that some delay is tolerable. + // Attempt to read from the local cache. + if v, ok := cache.Get(table, id); ok { + if content, ok := v.([]byte); ok { + if err := json.Unmarshal(content, dst); err == nil { + return nil + } else { + // Deserialization error indicates that the data is unusable. Delete it directly. + cache.Remove(table, id) + } + } + } + } + + b, err := NewSelectBuilderFromStruct(table, dst) + if err != nil { + return fmt.Errorf("create select builder error:%w", err) + } + b = b.Where(WhereFrom(id, nil)...) + + var ( + statement string + args []any + ) + statement, args = Build(ctx, b) + + if err := Get(ctx, dst, statement, args...); err != nil { + return err + } + content, err := json.Marshal(dst) + if err != nil { + log.Printf(ctx, WarnLevel, "Failed to marshal data for caching: %w, data: %+v", err.Error(), dst) + // 忽略序列化错误,顶多就是无法cache,无关紧要 + return nil + } + cache.Set(time.Second*10, table, id, content) + return nil +} + +// GetWhere 使用自定义条件跟新数据 +func GetWhere(ctx context.Context, dst interface{}, table string, filter KVs) error { + if table == "" { + table = TableName(dst) + } + builder := sb.NewStruct(dst).SelectFrom(table).Where(WhereFrom(filter, nil)...) + sql, args := Build(ctx, builder) + return Get(ctx, dst, sql, args...) +} + +// GetWhere 使用自定义条件跟新数据 +func SelectWhere(ctx context.Context, dst interface{}, table string, filter KVs) error { + if table == "" { + table = TableName(dst) + } + builder, err := NewSelectBuilderFromStruct(table, dst) + if err != nil { + return fmt.Errorf("new select builder error: %w", err) + } + builder = builder.Where(WhereFrom(filter, nil)...) + sql, args := Build(ctx, builder) + + return Select(ctx, dst, sql, args...) +} + +// Count select the count of rows in table which match the filter condition +func Count(ctx context.Context, table string, filter KVs) (int64, error) { + total := sql.NullInt64{} + b := sb.NewSelectBuilder().Select("COUNT(1) as total").From(table) + b = b.Where(WhereFromKVs(filter, nil)...) + sql, args := Build(ctx, b) + err := Get(ctx, &total, sql, args...) + if IsNotFound(err) { + err = nil + } + return total.Int64, err +} + +// Exist return true if the at least one row found in table by using where condition +func Exist(ctx context.Context, table string, filter KVs) (bool, error) { + n := sql.NullInt64{} + b := sb.NewSelectBuilder().Select("1").From(table).Limit(1) + b = b.Where(WhereFrom(filter, nil)...) + statement, args := Build(ctx, b) + err := Get(ctx, &n, statement, args...) + if err != nil { + if IsNotFound(err) { + return false, nil + } + return false, err + } + return true, nil +} + +// NewSelectBuilderFromStruct create select sql builder by data +func NewSelectBuilderFromStruct(table string, data any) (*sb.SelectBuilder, error) { + if table == "" { + table = TableName(data) + } + b := sb.NewSelectBuilder().From(table) + if data == nil { + b = b.Select("*") + return b, nil + } + t := dereferencedType(reflect.TypeOf(data)) + + if t.Kind() == reflect.Slice { + t = t.Elem() + } + + cols := make([]string, 0, t.NumField()) + for i := 0; i < t.NumField(); i++ { + fieldType := t.Field(i) + name, after := colNameFromTag(fieldType) + if name == "" { + continue + } + opts := ParseOptionStr(after) + if optv, ok := opts["select"]; ok && (optv == "-" || optv == "false") { + continue + } + cols = append(cols, name) + } + if len(cols) == 0 { + b = b.Select("*") + } else { + b = b.Select(cols...) + } + return b, nil +} diff --git a/update.go b/update.go new file mode 100644 index 0000000..5c79ef5 --- /dev/null +++ b/update.go @@ -0,0 +1,97 @@ +package ormx + +import ( + "context" + "database/sql/driver" + "reflect" + + sb "github.com/huandu/go-sqlbuilder" + "github.com/jmoiron/sqlx" +) + +// PatchByID updates the data by id in the table. +func PatchByID(ctx context.Context, table string, id int64, data any) error { + return PatchByIDTx(ctx, nil, table, id, data) +} + +// PatchByIDTx updates the data by id in the table using a transaction. +func PatchByIDTx(ctx context.Context, tx *sqlx.Tx, table string, id int64, data any) error { + ub, ok := NewUpdateBuilderFromStruct(data, table) + if !ok { + return nil + } + ub = ub.Where(WhereFrom(id, nil)...) + var ( + sql string + args []any + err error + ) + sql, args = Build(ctx, ub) + + if tx == nil { + _, err = Exec(ctx, sql, args...) + } else { + _, err = ExecTx(ctx, tx, sql, args...) + } + return err +} + +// PatchWhere updates the data that match the filter in the table. +// The filter is used as the condition and can be of type KVs, or struct. +func PatchWhere(ctx context.Context, table string, data any, filter any) (int64, error) { + return PatchWhereTx(ctx, nil, table, data, filter) +} + +// PatchWhereTx updates the data that matchthe filter in the table using a transaction. +// The filter is used as the condition and can be of type KVs, struct, []int64, int64. +func PatchWhereTx(ctx context.Context, tx *sqlx.Tx, table string, data any, filter any) (int64, error) { + ub, ok := NewUpdateBuilderFromStruct(data, table) + if !ok { + return 0, nil + } + ub = ub.Where(WhereFrom(filter, nil)...) + var ( + err error + sql string + args []any + r driver.Result + ) + sql, args = Build(ctx, ub) + if tx == nil { + r, err = Exec(ctx, sql, args...) + } else { + r, err = ExecTx(ctx, tx, sql, args...) + } + if err != nil { + return 0, err + } + return r.RowsAffected() +} + +// NewUpdateBuilderFromStruct 使用 data 数据定义 update builder +func NewUpdateBuilderFromStruct(data any, table string) (*sb.UpdateBuilder, bool) { + if table == "" { + table = TableName(data) + } + ub := sb.NewUpdateBuilder().Update(table) + v := dereferencedValue(reflect.ValueOf(data)) + t := dereferencedType(reflect.TypeOf(data)) + assigned := false + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + fieldType := t.Field(i) + if field.IsNil() { + continue + } + + name, after := colNameFromTag(fieldType) + if name == "" { + continue + } + opts := ParseOptionStr(after) + fieldValue := convertValueByDBType(dereferencedValue(field).Interface(), opts["type"]) + ub = ub.SetMore(ub.Assign(name, fieldValue)) + assigned = true + } + return ub, assigned +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..b81c4ec --- /dev/null +++ b/util.go @@ -0,0 +1,163 @@ +package ormx + +import ( + "database/sql" + "encoding/json" + "errors" + "reflect" + "strconv" + "strings" + "time" +) + +func Any2Time(i any) (time.Time, bool) { + if i == nil { + return time.Time{}, false + } + var num int64 + switch value := i.(type) { + case string: + if n, err := strconv.ParseInt(value, 10, 64); err == nil { + num = n + } + case float64: + num = int64(value) + case float32: + num = int64(value) + case int64: + num = int64(value) + case int32: + num = int64(value) + case int: + num = int64(value) + case uint64: + num = int64(value) + case uint32: + num = int64(value) + case uint: + num = int64(value) + case json.Number: + if n, err := value.Int64(); err == nil { + num = n + } + } + if num == 0 { + return time.Time{}, false + } + for num >= 9000000000 { + num /= 10 + } + return time.Unix(num, 0), true +} + +func dereferencedValue(v reflect.Value) reflect.Value { + for k := v.Kind(); k == reflect.Ptr || k == reflect.Interface; k = v.Kind() { + v = v.Elem() + } + + return v +} + +func dereferencedType(t reflect.Type) reflect.Type { + for k := t.Kind(); k == reflect.Ptr || k == reflect.Interface; k = t.Kind() { + t = t.Elem() + } + return t +} + +func dereferencedElemType(t reflect.Type) reflect.Type { + for k := t.Kind(); k == reflect.Ptr || k == reflect.Interface || k == reflect.Slice; k = t.Kind() { + t = t.Elem() + } + return t +} + +func Any2Slice(data any) []any { + v := reflect.ValueOf(data) + t := reflect.TypeOf(data) + if t.Kind() != reflect.Slice { + return []any{data} + } + var iface []interface{} + for i := 0; i < v.Len(); i++ { + iface = append(iface, v.Index(i).Interface()) + } + return iface +} + +type M map[string]any + +type KV struct { + Key string + Value any + Extra string +} + +type KVs []KV + +// KVsFromMap generate KVs from map +func KVsFromMap(dst KVs, filter map[string]any) KVs { + for k, v := range filter { + switch reflect.TypeOf(v).Kind() { + case reflect.Slice: + dst = append(dst, KV{Key: k, Value: v, Extra: "in"}) + default: + dst = append(dst, KV{Key: k, Value: v, Extra: "e"}) + } + } + return dst +} + +// IsNotFound 判断查询错误是否是 未找到错误 +func IsNotFound(err error) bool { + return errors.Is(err, sql.ErrNoRows) +} + +// IsDuplicate 判断查询错误是否是 未找到错误 +func IsDuplicate(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "Error 1062: Duplicate") +} + +// ParseOptionStr will decode key-value data from a string which format like k1:v1,k2:v2,k3:v3. +// it will always return a non-nil value map +// such as: +// - k1:v1,k2:v2 will parsed to {"k1":"v1","k2":"v2"} +// - k1,k2 will parsed to {"k1":"","k2":""} +// - k1:v2,k2 will parsed to {"k1":"v2","k2":""} +func ParseOptionStr(str string) map[string]string { + options := map[string]string{} + + kb, vb, stage := &strings.Builder{}, &strings.Builder{}, 'k' + for i := 0; i < len(str); i++ { + b := kb + if stage == 'v' { + b = vb + } + if str[i] == '\\' && i < len(str)-1 && str[i+1] == ',' { + b.WriteByte(',') + i++ + } + if str[i] == ':' { + stage = 'v' + continue + } else if str[i] == ',' { + if k, v := kb.String(), vb.String(); k != "" { + options[k] = v + } + stage = 'k' + kb.Reset() + vb.Reset() + } else { + b.WriteByte(str[i]) + } + } + + if k, v := kb.String(), vb.String(); k != "" { + options[k] = v + } + + return options +}