Skip to content

Commit

Permalink
fix #163: CTE API refactory; see issue for details
Browse files Browse the repository at this point in the history
  • Loading branch information
huandu committed Sep 8, 2024
1 parent 8cd72ce commit 01acaab
Show file tree
Hide file tree
Showing 13 changed files with 343 additions and 148 deletions.
82 changes: 65 additions & 17 deletions cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ const (
)

// With creates a new CTE builder with default flavor.
func With(tables ...*CTETableBuilder) *CTEBuilder {
func With(tables ...*CTEQueryBuilder) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().With(tables...)
}

// WithRecursive creates a new recursive CTE builder with default flavor.
func WithRecursive(tables ...*CTETableBuilder) *CTEBuilder {
func WithRecursive(tables ...*CTEQueryBuilder) *CTEBuilder {
return DefaultFlavor.NewCTEBuilder().WithRecursive(tables...)
}

Expand All @@ -28,8 +28,8 @@ func newCTEBuilder() *CTEBuilder {
// CTEBuilder is a CTE (Common Table Expression) builder.
type CTEBuilder struct {
recursive bool
tableNames []string
tableBuilderVars []string
queries []*CTEQueryBuilder
queryBuilderVars []string

args *Args

Expand All @@ -40,24 +40,22 @@ type CTEBuilder struct {
var _ Builder = new(CTEBuilder)

// With sets the CTE name and columns.
func (cteb *CTEBuilder) With(tables ...*CTETableBuilder) *CTEBuilder {
tableNames := make([]string, 0, len(tables))
tableBuilderVars := make([]string, 0, len(tables))
func (cteb *CTEBuilder) With(queries ...*CTEQueryBuilder) *CTEBuilder {
queryBuilderVars := make([]string, 0, len(queries))

for _, table := range tables {
tableNames = append(tableNames, table.TableName())
tableBuilderVars = append(tableBuilderVars, cteb.args.Add(table))
for _, query := range queries {
queryBuilderVars = append(queryBuilderVars, cteb.args.Add(query))
}

cteb.tableNames = tableNames
cteb.tableBuilderVars = tableBuilderVars
cteb.queries = queries
cteb.queryBuilderVars = queryBuilderVars
cteb.marker = cteMarkerAfterWith
return cteb
}

// WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword.
func (cteb *CTEBuilder) WithRecursive(tables ...*CTETableBuilder) *CTEBuilder {
cteb.With(tables...).recursive = true
func (cteb *CTEBuilder) WithRecursive(queries ...*CTEQueryBuilder) *CTEBuilder {
cteb.With(queries...).recursive = true
return cteb
}

Expand All @@ -67,6 +65,18 @@ func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder {
return sb.With(cteb).Select(col...)
}

// DeleteFrom creates a new DeleteBuilder to build a DELETE statement using this CTE.
func (cteb *CTEBuilder) DeleteFrom(table string) *DeleteBuilder {
db := cteb.args.Flavor.NewDeleteBuilder()
return db.With(cteb).DeleteFrom(table)
}

// Update creates a new UpdateBuilder to build an UPDATE statement using this CTE.
func (cteb *CTEBuilder) Update(table string) *UpdateBuilder {
ub := cteb.args.Flavor.NewUpdateBuilder()
return ub.With(cteb).Update(table)
}

// String returns the compiled CTE string.
func (cteb *CTEBuilder) String() string {
sql, _ := cteb.Build()
Expand All @@ -83,12 +93,12 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}
buf := newStringBuilder()
cteb.injection.WriteTo(buf, cteMarkerInit)

if len(cteb.tableBuilderVars) > 0 {
if len(cteb.queryBuilderVars) > 0 {
buf.WriteLeadingString("WITH ")
if cteb.recursive {
buf.WriteString("RECURSIVE ")
}
buf.WriteStrings(cteb.tableBuilderVars, ", ")
buf.WriteStrings(cteb.queryBuilderVars, ", ")
}

cteb.injection.WriteTo(buf, cteMarkerAfterWith)
Expand All @@ -110,5 +120,43 @@ func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder {

// TableNames returns all table names in a CTE.
func (cteb *CTEBuilder) TableNames() []string {
return cteb.tableNames
if len(cteb.queryBuilderVars) == 0 {
return nil
}

tableNames := make([]string, 0, len(cteb.queries))

for _, query := range cteb.queries {
tableNames = append(tableNames, query.TableName())
}

return tableNames
}

// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder right now.
func (cteb *CTEBuilder) tableNamesForSelect() []string {
cnt := 0

// It's rare that the ShouldAddToTableList() returns true.
// Count it before allocating any memory for better performance.
for _, query := range cteb.queries {
if query.ShouldAddToTableList() {
cnt++
}
}

if cnt == 0 {
return nil
}

tableNames := make([]string, 0, cnt)

for _, query := range cteb.queries {
if query.ShouldAddToTableList() {
tableNames = append(tableNames, query.TableName())
}
}

return tableNames
}
8 changes: 5 additions & 3 deletions cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func ExampleWith() {

func ExampleWithRecursive() {
sb := WithRecursive(
CTETable("source_accounts", "id", "parent_id").As(
CTEQuery("source_accounts", "id", "parent_id").As(
UnionAll(
Select("p.id", "p.parent_id").
From("accounts AS p").
Expand Down Expand Up @@ -85,7 +85,7 @@ func ExampleCTEBuilder() {
func TestCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
ctetb := newCTETableBuilder()
ctetb := newCTEQueryBuilder()
cteb.SQL("/* init */")
cteb.With(ctetb)
cteb.SQL("/* after with */")
Expand All @@ -97,6 +97,8 @@ func TestCTEBuilder(t *testing.T) {
ctetb.As(Select("a", "b").From("t"))
ctetb.SQL("/* after table as */")

a.Equal(cteb.TableNames(), []string{ctetb.TableName()})

sql, args := cteb.Build()
a.Equal(sql, "/* init */ WITH /* table init */ t (a, b) /* after table */ AS (SELECT a, b FROM t) /* after table as */ /* after with */")
a.Assert(args == nil)
Expand All @@ -109,7 +111,7 @@ func TestRecursiveCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
cteb.recursive = true
ctetb := newCTETableBuilder()
ctetb := newCTEQueryBuilder()
cteb.SQL("/* init */")
cteb.With(ctetb)
cteb.SQL("/* after with */")
Expand Down
135 changes: 135 additions & 0 deletions ctequery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright 2024 Huan Du. All rights reserved.
// Licensed under the MIT license that can be found in the LICENSE file.

package sqlbuilder

const (
cteQueryMarkerInit injectionMarker = iota
cteQueryMarkerAfterTable
cteQueryMarkerAfterAs
)

// CTETable creates a new CTE query builder with default flavor, marking it as a table.
//
// The resulting CTE query can be used in a `SelectBuilder“, where its table name will be
// automatically included in the FROM clause.
func CTETable(name string, cols ...string) *CTEQueryBuilder {
return DefaultFlavor.NewCTEQueryBuilder().AddToTableList().Table(name, cols...)
}

// CTEQuery creates a new CTE query builder with default flavor.
func CTEQuery(name string, cols ...string) *CTEQueryBuilder {
return DefaultFlavor.NewCTEQueryBuilder().Table(name, cols...)
}

func newCTEQueryBuilder() *CTEQueryBuilder {
return &CTEQueryBuilder{
args: &Args{},
injection: newInjection(),
}
}

// CTEQueryBuilder is a builder to build one table in CTE (Common Table Expression).
type CTEQueryBuilder struct {
name string
cols []string
builderVar string

// if true, this query's table name will be automatically added to the table list
// in FROM clause of SELECT statement.
autoAddToTableList bool

args *Args

injection *injection
marker injectionMarker
}

var _ Builder = new(CTEQueryBuilder)

// CTETableBuilder is an alias of CTEQueryBuilder for backward compatibility.
// Deprecated: use CTEQueryBuilder instead.
type CTETableBuilder = CTEQueryBuilder

// Table sets the table name and columns in a CTE table.
func (ctetb *CTEQueryBuilder) Table(name string, cols ...string) *CTEQueryBuilder {
ctetb.name = name
ctetb.cols = cols
ctetb.marker = cteQueryMarkerAfterTable
return ctetb
}

// As sets the builder to select data.
func (ctetb *CTEQueryBuilder) As(builder Builder) *CTEQueryBuilder {
ctetb.builderVar = ctetb.args.Add(builder)
ctetb.marker = cteQueryMarkerAfterAs
return ctetb
}

// AddToTableList sets flag to add table name to table list in FROM clause of SELECT statement.
func (ctetb *CTEQueryBuilder) AddToTableList() *CTEQueryBuilder {
ctetb.autoAddToTableList = true
return ctetb
}

// ShouldAddToTableList returns flag to add table name to table list in FROM clause of SELECT statement.
func (ctetb *CTEQueryBuilder) ShouldAddToTableList() bool {
return ctetb.autoAddToTableList
}

// String returns the compiled CTE string.
func (ctetb *CTEQueryBuilder) String() string {
sql, _ := ctetb.Build()
return sql
}

// Build returns compiled CTE string and args.
func (ctetb *CTEQueryBuilder) Build() (sql string, args []interface{}) {
return ctetb.BuildWithFlavor(ctetb.args.Flavor)
}

// BuildWithFlavor builds a CTE with the specified flavor and initial arguments.
func (ctetb *CTEQueryBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
buf := newStringBuilder()
ctetb.injection.WriteTo(buf, cteQueryMarkerInit)

if ctetb.name != "" {
buf.WriteLeadingString(ctetb.name)

if len(ctetb.cols) > 0 {
buf.WriteLeadingString("(")
buf.WriteStrings(ctetb.cols, ", ")
buf.WriteString(")")
}

ctetb.injection.WriteTo(buf, cteQueryMarkerAfterTable)
}

if ctetb.builderVar != "" {
buf.WriteLeadingString("AS (")
buf.WriteString(ctetb.builderVar)
buf.WriteRune(')')

ctetb.injection.WriteTo(buf, cteQueryMarkerAfterAs)
}

return ctetb.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
}

// SetFlavor sets the flavor of compiled sql.
func (ctetb *CTEQueryBuilder) SetFlavor(flavor Flavor) (old Flavor) {
old = ctetb.args.Flavor
ctetb.args.Flavor = flavor
return
}

// SQL adds an arbitrary sql to current position.
func (ctetb *CTEQueryBuilder) SQL(sql string) *CTEQueryBuilder {
ctetb.injection.SQL(ctetb.marker, sql)
return ctetb
}

// TableName returns the CTE table name.
func (ctetb *CTEQueryBuilder) TableName() string {
return ctetb.name
}
Loading

0 comments on commit 01acaab

Please sign in to comment.