Skip to content

Commit d50e331

Browse files
authored
Merge pull request #17 from n-r-w/CTE_bool
CTE + Not
2 parents 2b8f553 + a91cde4 commit d50e331

File tree

7 files changed

+385
-8
lines changed

7 files changed

+385
-8
lines changed

README.md

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,15 @@ sq.Sum(subQuery)
144144
Column(sq.Expr("id = ANY(?)", []int{1,2,3}))
145145
```
146146
147-
### Support for `IN` and `NOT IN` clause
147+
### Support for `IN`, `NOT` and `NOT IN` clause
148148
149149
```go
150-
In("id", []int{1, 2, 3})
151-
NotIn("id", subQuery)
150+
In("id", []int{1, 2, 3}) // id=ANY(ARRAY[1,2,3])
151+
NotIn("id", subQuery) // id NOT IN (<subQuery>)
152+
153+
Not(Select("col").From("table")) // NOT (SELECT col FROM table)
154+
// double NOT is removed
155+
Not(Not(Select("col").From("table"))) // SELECT col FROM table
152156
```
153157
154158
### Range function
@@ -233,6 +237,24 @@ Alias("u", "pref").OrderBy("id").
233237
// SELECT SELECT u.id AS pref_id, u.name AS pref_name FROM users u GROUP BY u.id AS pref_id, u.name AS pref_name ORDER BY u.id AS pref_id
234238
```
235239
240+
### CTE support (taken from <https://github.com/joshring/squirrel>)
241+
242+
```go
243+
With("alias").As(
244+
Select("col1").From("table"),
245+
).Select(
246+
Select("col2").From("alias"),
247+
)
248+
// WITH alias AS (SELECT col1 FROM table) SELECT col2 FROM alias
249+
250+
WithRecursive("alias").As(
251+
Select("col1").From("table"),
252+
).Select(
253+
Select("col2").From("alias"),
254+
)
255+
// WITH RECURSIVE alias AS (SELECT col1 FROM table) SELECT col2 FROM alias
256+
```
257+
236258
## Miscellaneous
237259
238260
- Added a linter and fixed all warnings.

case_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,14 @@ func TestSqlTypeNameHelper(t *testing.T) {
215215
}
216216

217217
for _, tt := range tests {
218-
t.Run(tt.name, func(t *testing.T) {
218+
t.Run(tt.name, func(t1 *testing.T) {
219219
got, err := sqlTypeNameHelper(tt.arg)
220220
if (err != nil) != tt.wantErr {
221-
t.Errorf("sqlTypeNameHelper() error = %v, wantErr %v", err, tt.wantErr)
221+
t1.Errorf("sqlTypeNameHelper() error = %v, wantErr %v", err, tt.wantErr)
222222
return
223223
}
224224
if got != tt.want {
225-
t.Errorf("sqlTypeNameHelper() = %v, want %v", got, tt.want)
225+
t1.Errorf("sqlTypeNameHelper() = %v, want %v", got, tt.want)
226226
}
227227
})
228228
}

cte.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package squirrel
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
7+
"github.com/lann/builder"
8+
)
9+
10+
// Common Table Expressions helper
11+
// e.g.
12+
// WITH cte AS (
13+
// ...
14+
// ), cte_2 AS (
15+
// ...
16+
// )
17+
// SELECT ... FROM cte ... cte_2;
18+
19+
type commonTableExpressionsData struct {
20+
PlaceholderFormat PlaceholderFormat
21+
Recursive bool
22+
CurrentCteName string
23+
Ctes []Sqlizer
24+
Statement Sqlizer
25+
}
26+
27+
func (d *commonTableExpressionsData) toSql() (sqlStr string, args []any, err error) {
28+
if len(d.Ctes) == 0 {
29+
err = fmt.Errorf("common table expressions statements must have at least one label and subquery")
30+
return "", nil, err
31+
}
32+
33+
if d.Statement == nil {
34+
err = fmt.Errorf("common table expressions must one of the following final statement: (select, insert, replace, update, delete)")
35+
return "", nil, err
36+
}
37+
38+
sql := &bytes.Buffer{}
39+
40+
_, _ = sql.WriteString("WITH ")
41+
if d.Recursive {
42+
_, _ = sql.WriteString("RECURSIVE ")
43+
}
44+
45+
args, err = appendToSql(d.Ctes, sql, ", ", args)
46+
if err != nil {
47+
return "", nil, err
48+
}
49+
50+
_, _ = sql.WriteString(" ")
51+
args, err = appendToSql([]Sqlizer{d.Statement}, sql, "", args)
52+
if err != nil {
53+
return "", nil, err
54+
}
55+
56+
sqlStr = sql.String()
57+
return sqlStr, args, nil
58+
}
59+
60+
func (d *commonTableExpressionsData) ToSql() (sql string, args []any, err error) {
61+
return d.toSql()
62+
}
63+
64+
// Builder
65+
66+
// CommonTableExpressionsBuilder builds CTE (Common Table Expressions) SQL statements.
67+
type CommonTableExpressionsBuilder builder.Builder
68+
69+
func init() {
70+
builder.Register(CommonTableExpressionsBuilder{}, commonTableExpressionsData{})
71+
}
72+
73+
// Format methods
74+
75+
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
76+
// query.
77+
func (b CommonTableExpressionsBuilder) PlaceholderFormat(f PlaceholderFormat) CommonTableExpressionsBuilder {
78+
return builder.Set(b, "PlaceholderFormat", f).(CommonTableExpressionsBuilder)
79+
}
80+
81+
// SQL methods
82+
83+
// ToSql builds the query into a SQL string and bound args.
84+
func (b CommonTableExpressionsBuilder) ToSql() (string, []any, error) {
85+
data := builder.GetStruct(b).(commonTableExpressionsData)
86+
return data.ToSql()
87+
}
88+
89+
// MustSql builds the query into a SQL string and bound args.
90+
// It panics if there are any errors.
91+
func (b CommonTableExpressionsBuilder) MustSql() (string, []any) {
92+
sql, args, err := b.ToSql()
93+
if err != nil {
94+
panic(err)
95+
}
96+
return sql, args
97+
}
98+
99+
func (b CommonTableExpressionsBuilder) Recursive(recursive bool) CommonTableExpressionsBuilder {
100+
return builder.Set(b, "Recursive", recursive).(CommonTableExpressionsBuilder)
101+
}
102+
103+
// Cte starts a new cte
104+
func (b CommonTableExpressionsBuilder) Cte(cte string) CommonTableExpressionsBuilder {
105+
return builder.Set(b, "CurrentCteName", cte).(CommonTableExpressionsBuilder)
106+
}
107+
108+
// As sets the expression for the Cte
109+
func (b CommonTableExpressionsBuilder) As(as SelectBuilder) CommonTableExpressionsBuilder {
110+
data := builder.GetStruct(b).(commonTableExpressionsData)
111+
return builder.Append(b, "Ctes", cteExpr{as, data.CurrentCteName}).(CommonTableExpressionsBuilder)
112+
}
113+
114+
// Select finalizes the CommonTableExpressionsBuilder with a SELECT
115+
func (b CommonTableExpressionsBuilder) Select(statement SelectBuilder) CommonTableExpressionsBuilder {
116+
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
117+
}
118+
119+
// Insert finalizes the CommonTableExpressionsBuilder with an INSERT
120+
func (b CommonTableExpressionsBuilder) Insert(statement InsertBuilder) CommonTableExpressionsBuilder {
121+
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
122+
}
123+
124+
// Replace finalizes the CommonTableExpressionsBuilder with a REPLACE
125+
func (b CommonTableExpressionsBuilder) Replace(statement InsertBuilder) CommonTableExpressionsBuilder {
126+
return b.Insert(statement)
127+
}
128+
129+
// Update finalizes the CommonTableExpressionsBuilder with an UPDATE
130+
func (b CommonTableExpressionsBuilder) Update(statement UpdateBuilder) CommonTableExpressionsBuilder {
131+
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
132+
}
133+
134+
// Delete finalizes the CommonTableExpressionsBuilder with a DELETE
135+
func (b CommonTableExpressionsBuilder) Delete(statement DeleteBuilder) CommonTableExpressionsBuilder {
136+
return builder.Set(b, "Statement", statement).(CommonTableExpressionsBuilder)
137+
}

cte_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package squirrel
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestWithAsQuery_OneSubquery(t *testing.T) {
10+
w := With("lab").As(
11+
Select("col").From("tab").
12+
Where("simple AND NOT hard"),
13+
).Select(
14+
Select("col").
15+
From("lab"),
16+
)
17+
q, _, err := w.ToSql()
18+
assert.NoError(t, err)
19+
20+
expectedSql := "WITH lab AS (SELECT col FROM tab WHERE simple AND NOT hard) SELECT col FROM lab"
21+
assert.Equal(t, expectedSql, q)
22+
23+
w = WithRecursive("lab").As(
24+
Select("col").From("tab").
25+
Where("simple").
26+
Where("NOT hard"),
27+
).Select(Select("col").
28+
From("lab"),
29+
)
30+
q, _, err = w.ToSql()
31+
assert.NoError(t, err)
32+
33+
expectedSql = "WITH RECURSIVE lab AS (" +
34+
"SELECT col FROM tab WHERE simple AND NOT hard" +
35+
") " +
36+
"SELECT col FROM lab"
37+
assert.Equal(t, expectedSql, q)
38+
}
39+
40+
func TestWithAsQuery_TwoSubqueries(t *testing.T) {
41+
w := With("lab_1").As(
42+
Select("col_1", "col_common").From("tab_1").
43+
Where("simple").
44+
Where("NOT hard"),
45+
).Cte("lab_2").As(
46+
Select("col_2", "col_common").From("tab_2"),
47+
).Select(Select("col_1", "col_2", "col_common").
48+
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common"),
49+
)
50+
q, _, err := w.ToSql()
51+
assert.NoError(t, err)
52+
53+
expectedSql := "WITH lab_1 AS (" +
54+
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard" +
55+
"), lab_2 AS (" +
56+
"SELECT col_2, col_common FROM tab_2" +
57+
") " +
58+
"SELECT col_1, col_2, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common"
59+
assert.Equal(t, expectedSql, q)
60+
}
61+
62+
func TestWithAsQuery_ManySubqueries(t *testing.T) {
63+
w := With("lab_1").As(
64+
Select("col_1", "col_common").From("tab_1").
65+
Where("simple").
66+
Where("NOT hard"),
67+
).Cte("lab_2").As(
68+
Select("col_2", "col_common").From("tab_2"),
69+
).Cte("lab_3").As(
70+
Select("col_3", "col_common").From("tab_3"),
71+
).Cte("lab_4").As(
72+
Select("col_4", "col_common").From("tab_4"),
73+
).Select(
74+
Select("col_1", "col_2", "col_3", "col_4", "col_common").
75+
From("lab_1").Join("lab_2 ON lab_1.col_common = lab_2.col_common").
76+
Join("lab_3 ON lab_1.col_common = lab_3.col_common").
77+
Join("lab_4 ON lab_1.col_common = lab_4.col_common"),
78+
)
79+
q, _, err := w.ToSql()
80+
assert.NoError(t, err)
81+
82+
expectedSql := "WITH lab_1 AS (" +
83+
"SELECT col_1, col_common FROM tab_1 WHERE simple AND NOT hard" +
84+
"), lab_2 AS (" +
85+
"SELECT col_2, col_common FROM tab_2" +
86+
"), lab_3 AS (" +
87+
"SELECT col_3, col_common FROM tab_3" +
88+
"), lab_4 AS (" +
89+
"SELECT col_4, col_common FROM tab_4" +
90+
") " +
91+
"SELECT col_1, col_2, col_3, col_4, col_common FROM lab_1 JOIN lab_2 ON lab_1.col_common = lab_2.col_common JOIN lab_3 ON lab_1.col_common = lab_3.col_common JOIN lab_4 ON lab_1.col_common = lab_4.col_common"
92+
assert.Equal(t, expectedSql, q)
93+
}
94+
95+
func TestWithAsQuery_Insert(t *testing.T) {
96+
w := With("lab").As(
97+
Select("col").From("tab").
98+
Where("simple").
99+
Where("NOT hard"),
100+
).Insert(Insert("ins_tab").Columns("ins_col").Select(Select("col").From("lab")))
101+
q, _, err := w.ToSql()
102+
assert.NoError(t, err)
103+
104+
expectedSql := "WITH lab AS (" +
105+
"SELECT col FROM tab WHERE simple AND NOT hard" +
106+
") " +
107+
"INSERT INTO ins_tab (ins_col) SELECT col FROM lab"
108+
assert.Equal(t, expectedSql, q)
109+
}
110+
111+
func TestWithAsQuery_Update(t *testing.T) {
112+
w := With("lab").As(
113+
Select("col", "common_col").From("tab").
114+
Where("simple").
115+
Where("NOT hard"),
116+
).Update(
117+
Update("upd_tab, lab").
118+
Set("upd_col", Expr("lab.col")).
119+
Where("common_col = lab.common_col"),
120+
)
121+
122+
q, _, err := w.ToSql()
123+
assert.NoError(t, err)
124+
125+
expectedSql := "WITH lab AS (" +
126+
"SELECT col, common_col FROM tab WHERE simple AND NOT hard" +
127+
") " +
128+
"UPDATE upd_tab, lab SET upd_col = lab.col WHERE common_col = lab.common_col"
129+
130+
assert.Equal(t, expectedSql, q)
131+
}

expr.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ type aliasExpr struct {
127127
// Ex:
128128
//
129129
// .Column(Alias(caseStmt, "case_column"))
130-
func Alias(e Sqlizer, alias string) aliasExpr {
131-
return aliasExpr{e, alias}
130+
func Alias(e Sqlizer, a string) aliasExpr {
131+
return aliasExpr{e, a}
132132
}
133133

134134
func (e aliasExpr) ToSql() (sql string, args []any, err error) {
@@ -741,3 +741,45 @@ func clearEmptyValue(v any) any {
741741

742742
return nil
743743
}
744+
745+
type cteExpr struct {
746+
expr Sqlizer
747+
cte string
748+
}
749+
750+
// Cte allows to define CTE (Common Table Expressions) in SQL query
751+
func Cte(e Sqlizer, cte string) cteExpr {
752+
return cteExpr{e, cte}
753+
}
754+
755+
// ToSql builds the query into a SQL string and bound args.
756+
func (e cteExpr) ToSql() (sql string, args []any, err error) {
757+
sql, args, err = e.expr.ToSql()
758+
if err == nil {
759+
sql = fmt.Sprintf("%s AS (%s)", e.cte, sql)
760+
}
761+
return
762+
}
763+
764+
type notExpr struct {
765+
expr Sqlizer
766+
}
767+
768+
// ToSql builds the query into a SQL string and bound args.
769+
func (e notExpr) ToSql() (sql string, args []any, err error) {
770+
sql, args, err = e.expr.ToSql()
771+
if err == nil {
772+
sql = fmt.Sprintf("NOT (%s)", sql)
773+
}
774+
return
775+
}
776+
777+
// Not is a helper function to negate a condition.
778+
func Not(e Sqlizer) Sqlizer {
779+
// check nested NOT
780+
if n, ok := e.(notExpr); ok {
781+
return n.expr
782+
}
783+
784+
return notExpr{e}
785+
}

0 commit comments

Comments
 (0)