Skip to content

Commit

Permalink
dialect/sql: support capturing predicates in selectors
Browse files Browse the repository at this point in the history
This allows custom predicates mutating the root querying and still respect the AND/OR/NOT semantics
  • Loading branch information
a8m committed Jun 21, 2023
1 parent 4787899 commit 808edd1
Show file tree
Hide file tree
Showing 132 changed files with 475 additions and 2,600 deletions.
28 changes: 28 additions & 0 deletions dialect/sql/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,7 @@ type Selector struct {
selection []selection
from []TableView
joins []join
collected [][]*Predicate
where *Predicate
or bool
not bool
Expand Down Expand Up @@ -2385,8 +2386,35 @@ func (s *Selector) Offset(offset int) *Selector {
return s
}

// CollectPredicates indicates the appended predicated should be collected
// and not appended to the `WHERE` clause.
func (s *Selector) CollectPredicates() *Selector {
s.collected = append(s.collected, []*Predicate{})
return s
}

// CollectedPredicates returns the collected predicates.
func (s *Selector) CollectedPredicates() []*Predicate {
if len(s.collected) == 0 {
return nil
}
return s.collected[len(s.collected)-1]
}

// UncollectedPredicates stop collecting predicates.
func (s *Selector) UncollectedPredicates() *Selector {
if len(s.collected) > 0 {
s.collected = s.collected[:len(s.collected)-1]
}
return s
}

// Where sets or appends the given predicate to the statement.
func (s *Selector) Where(p *Predicate) *Selector {
if len(s.collected) > 0 {
s.collected[len(s.collected)-1] = append(s.collected[len(s.collected)-1], p)
return s
}
if s.not {
p = Not(p)
s.not = false
Expand Down
57 changes: 57 additions & 0 deletions dialect/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,63 @@ func FieldContainsFold(name string, substr string) func(*Selector) {
}
}

// AndPredicates returns a new predicate for joining multiple generated predicates with AND between them.
func AndPredicates[P ~func(*Selector)](predicates ...P) func(*Selector) {
return func(s *Selector) {
s.CollectPredicates()
for _, p := range predicates {
p(s)
}
collected := s.CollectedPredicates()
s.UncollectedPredicates()
switch len(collected) {
case 0:
case 1:
s.Where(collected[0])
default:
s.Where(And(collected...))
}
}
}

// OrPredicates returns a new predicate for joining multiple generated predicates with OR between them.
func OrPredicates[P ~func(*Selector)](predicates ...P) func(*Selector) {
return func(s *Selector) {
s.CollectPredicates()
for _, p := range predicates {
p(s)
}
collected := s.CollectedPredicates()
s.UncollectedPredicates()
switch len(collected) {
case 0:
case 1:
s.Where(collected[0])
default:
s.Where(Or(collected...))
}
}
}

// NotPredicates wraps the generated predicates with NOT. For example, NOT(P), NOT((P1 AND P2)).
func NotPredicates[P ~func(*Selector)](predicates ...P) func(*Selector) {
return func(s *Selector) {
s.CollectPredicates()
for _, p := range predicates {
p(s)
}
collected := s.CollectedPredicates()
s.UncollectedPredicates()
switch len(collected) {
case 0:
case 1:
s.Where(Not(collected[0]))
default:
s.Where(Not(And(collected...)))
}
}
}

// ColumnCheck is a function that verifies whether the
// specified column exists within the given table.
type ColumnCheck func(table, column string) error
Expand Down
23 changes: 3 additions & 20 deletions entc/gen/template/dialect/sql/predicate.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,13 @@ in the LICENSE file in the root directory of this source tree.
{{- end }}

{{ define "dialect/sql/predicate/and" -}}
func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for _, p := range predicates {
p(s1)
}
s.Where(s1.P())
}
sql.AndPredicates(predicates...)
{{- end }}

{{ define "dialect/sql/predicate/or" -}}
func(s *sql.Selector) {
s1 := s.Clone().SetP(nil)
for i, p := range predicates {
if i > 0 {
s1.Or()
}
p(s1)
}
s.Where(s1.P())
}
sql.OrPredicates(predicates...)
{{- end }}

{{ define "dialect/sql/predicate/not" -}}
func(s *sql.Selector) {
p(s.Not())
}
sql.NotPredicates(p)
{{- end }}
23 changes: 3 additions & 20 deletions entc/integration/cascadelete/ent/comment/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/cascadelete/ent/post/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/cascadelete/ent/user/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/config/ent/user/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/customid/ent/account/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/customid/ent/blob/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 3 additions & 20 deletions entc/integration/customid/ent/bloblink/where.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 808edd1

Please sign in to comment.