Skip to content

Commit

Permalink
fix #169: support IS [NOT] DISTINCT FROM operator
Browse files Browse the repository at this point in the history
  • Loading branch information
huandu committed Sep 23, 2024
1 parent fa64168 commit 975bcfd
Show file tree
Hide file tree
Showing 3 changed files with 593 additions and 332 deletions.
246 changes: 144 additions & 102 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Args struct {
// The default flavor used by `Args#Compile`
Flavor Flavor

args []interface{}
argValues []interface{}
namedArgs map[string]int
sqlNamedArgs map[string]int
onlyNamed bool
Expand Down Expand Up @@ -47,7 +47,7 @@ func (args *Args) Add(arg interface{}) string {
}

func (args *Args) add(arg interface{}) int {
idx := len(args.args)
idx := len(args.argValues)

switch a := arg.(type) {
case sql.NamedArg:
Expand All @@ -56,7 +56,7 @@ func (args *Args) add(arg interface{}) int {
}

if p, ok := args.sqlNamedArgs[a.Name]; ok {
arg = args.args[p]
arg = args.argValues[p]
break
}

Expand All @@ -67,7 +67,7 @@ func (args *Args) add(arg interface{}) int {
}

if p, ok := args.namedArgs[a.name]; ok {
arg = args.args[p]
arg = args.argValues[p]
break
}

Expand All @@ -77,7 +77,7 @@ func (args *Args) add(arg interface{}) int {
return idx
}

args.args = append(args.args, arg)
args.argValues = append(args.argValues, arg)
return idx
}

Expand All @@ -97,55 +97,58 @@ func (args *Args) Compile(format string, initialValue ...interface{}) (query str
//
// See doc for `Compile` to learn details.
func (args *Args) CompileWithFlavor(format string, flavor Flavor, initialValue ...interface{}) (query string, values []interface{}) {
buf := newStringBuilder()
idx := strings.IndexRune(format, '$')
offset := 0
values = initialValue
ctx := &argsCompileContext{
stringBuilder: newStringBuilder(),
Flavor: flavor,
Values: initialValue,
}

if flavor == invalidFlavor {
flavor = DefaultFlavor
if ctx.Flavor == invalidFlavor {
ctx.Flavor = DefaultFlavor
}

for idx >= 0 && len(format) > 0 {
if idx > 0 {
buf.WriteString(format[:idx])
ctx.WriteString(format[:idx])
}

format = format[idx+1:]

// Treat the $ at the end of format is a normal $ rune.
if len(format) == 0 {
buf.WriteRune('$')
ctx.WriteRune('$')
break
}

if r := format[0]; r == '$' {
buf.WriteRune('$')
ctx.WriteRune('$')
format = format[1:]
} else if r == '{' {
format, values = args.compileNamed(buf, flavor, format, values)
format = args.compileNamed(ctx, format)
} else if !args.onlyNamed && '0' <= r && r <= '9' {
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
format, offset = args.compileDigits(ctx, format, offset)
} else if !args.onlyNamed && r == '?' {
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
format, offset = args.compileSuccessive(ctx, format[1:], offset)
} else {
// For unknown $ expression format, treat it as a normal $ rune.
buf.WriteRune('$')
ctx.WriteRune('$')
}

idx = strings.IndexRune(format, '$')
}

if len(format) > 0 {
buf.WriteString(format)
ctx.WriteString(format)
}

query = buf.String()
values = args.mergeSQLNamedArgs(values)
query = ctx.String()
values = args.mergeSQLNamedArgs(ctx)
return
}

func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
func (args *Args) compileNamed(ctx *argsCompileContext, format string) string {
i := 1

for ; i < len(format) && format[i] != '}'; i++ {
Expand All @@ -154,20 +157,20 @@ func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string,

// Invalid $ format. Ignore it.
if i == len(format) {
return format, values
return format
}

name := format[1:i]
format = format[i+1:]

if p, ok := args.namedArgs[name]; ok {
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
format, _ = args.compileSuccessive(ctx, format, p)
}

return format, values
return format
}

func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
func (args *Args) compileDigits(ctx *argsCompileContext, format string, offset int) (string, int) {
i := 1

for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
Expand All @@ -178,91 +181,37 @@ func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string
format = format[i:]

if pointer, err := strconv.Atoi(digits); err == nil {
return args.compileSuccessive(buf, flavor, format, values, pointer)
return args.compileSuccessive(ctx, format, pointer)
}

return format, values, offset
return format, offset
}

func (args *Args) compileSuccessive(buf *stringBuilder, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
if offset >= len(args.args) {
return format, values, offset
func (args *Args) compileSuccessive(ctx *argsCompileContext, format string, offset int) (string, int) {
if offset >= len(args.argValues) {
return format, offset
}

arg := args.args[offset]
values = args.compileArg(buf, flavor, values, arg)

return format, values, offset + 1
}

func (args *Args) compileArg(buf *stringBuilder, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
switch a := arg.(type) {
case Builder:
var s string
s, values = a.BuildWithFlavor(flavor, values...)
buf.WriteString(s)
case sql.NamedArg:
buf.WriteRune('@')
buf.WriteString(a.Name)
case rawArgs:
buf.WriteString(a.expr)
case listArgs:
if a.isTuple {
buf.WriteRune('(')
}

if len(a.args) > 0 {
values = args.compileArg(buf, flavor, values, a.args[0])
}

for i := 1; i < len(a.args); i++ {
buf.WriteString(", ")
values = args.compileArg(buf, flavor, values, a.args[i])
}

if a.isTuple {
buf.WriteRune(')')
}

default:
switch flavor {
case MySQL, SQLite, CQL, ClickHouse, Presto, Informix:
buf.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(buf, "$%d", len(values)+1)
case SQLServer:
fmt.Fprintf(buf, "@p%d", len(values)+1)
case Oracle:
fmt.Fprintf(buf, ":%d", len(values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
}

namedValues := parseNamedArgs(values)
arg := args.argValues[offset]
ctx.WriteValue(arg)

if n := len(namedValues); n == 0 {
values = append(values, arg)
} else {
index := len(values) - n
values = append(values[:index+1], namedValues...)
values[index] = arg
}
}

return values
return format, offset + 1
}

func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
if len(args.sqlNamedArgs) == 0 {
return values
func (args *Args) mergeSQLNamedArgs(ctx *argsCompileContext) []interface{} {
if len(args.sqlNamedArgs) == 0 && len(ctx.NamedArgs) == 0 {
return ctx.Values
}

namedValues := parseNamedArgs(values)
existingNames := make(map[string]struct{}, len(namedValues))
values := ctx.Values
existingNames := make(map[string]struct{}, len(ctx.NamedArgs))

for _, v := range namedValues {
if a, ok := v.(sql.NamedArg); ok {
existingNames[a.Name] = struct{}{}
// Add all named args to values.
// Remove duplicated named args in this step.
for _, arg := range ctx.NamedArgs {
if _, ok := existingNames[arg.Name]; !ok {
existingNames[arg.Name] = struct{}{}
values = append(values, arg)
}
}

Expand All @@ -280,19 +229,21 @@ func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
sort.Ints(ints)

for _, i := range ints {
values = append(values, args.args[i])
values = append(values, args.argValues[i])
}

return values
}

func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
func parseNamedArgs(initialValue []interface{}) (values []interface{}, namedValues []sql.NamedArg) {
if len(initialValue) == 0 {
return nil
values = initialValue
return
}

// sql.NamedArgs must be placed at the end of the initial value.
i := len(initialValue)
size := len(initialValue)
i := size

for ; i > 0; i-- {
switch initialValue[i-1].(type) {
Expand All @@ -303,6 +254,97 @@ func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
break
}

namedValues = initialValue[i:]
if i == size {
values = initialValue
return
}

values = initialValue[:i]
namedValues = make([]sql.NamedArg, 0, size-i)

for ; i < size; i++ {
namedValues = append(namedValues, initialValue[i].(sql.NamedArg))
}

return
}

type argsCompileContext struct {
*stringBuilder

Flavor Flavor
Values []interface{}
NamedArgs []sql.NamedArg
}

func (ctx *argsCompileContext) WriteValue(arg interface{}) {
switch a := arg.(type) {
case Builder:
s, values := a.BuildWithFlavor(ctx.Flavor, ctx.Values...)
ctx.WriteString(s)

// Add all values to ctx.
// Named args must be located at the end of values.
values, namedArgs := parseNamedArgs(values)
ctx.Values = values
ctx.NamedArgs = append(ctx.NamedArgs, namedArgs...)

case sql.NamedArg:
ctx.WriteRune('@')
ctx.WriteString(a.Name)
ctx.NamedArgs = append(ctx.NamedArgs, a)

case rawArgs:
ctx.WriteString(a.expr)

case listArgs:
if a.isTuple {
ctx.WriteRune('(')
}

if len(a.args) > 0 {
ctx.WriteValue(a.args[0])
}

for i := 1; i < len(a.args); i++ {
ctx.WriteString(", ")
ctx.WriteValue(a.args[i])
}

if a.isTuple {
ctx.WriteRune(')')
}

case condBuilder:
a.Builder(ctx)

default:
switch ctx.Flavor {
case MySQL, SQLite, CQL, ClickHouse, Presto, Informix:
ctx.WriteRune('?')
case PostgreSQL:
fmt.Fprintf(ctx, "$%d", len(ctx.Values)+1)
case SQLServer:
fmt.Fprintf(ctx, "@p%d", len(ctx.Values)+1)
case Oracle:
fmt.Fprintf(ctx, ":%d", len(ctx.Values)+1)
default:
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", ctx.Flavor, int(ctx.Flavor)))
}

ctx.Values = append(ctx.Values, arg)
}
}

func (ctx *argsCompileContext) WriteValues(values []interface{}, sep string) {
if len(values) == 0 {
return
}

ctx.WriteValue(values[0])

for _, v := range values[1:] {
ctx.WriteString(sep)
ctx.WriteValue(v)
}
}
Loading

0 comments on commit 975bcfd

Please sign in to comment.