Skip to content

Commit a3fc2d8

Browse files
authored
Merge pull request #463 from go-jet/columnlist-set
Add support for assigning one ColumnList to another in INSERT and UPDATE queries.
2 parents c9e6fb1 + 34ee39c commit a3fc2d8

File tree

8 files changed

+171
-45
lines changed

8 files changed

+171
-45
lines changed

cmd/jet/version.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
package main
22

3-
const version = "v2.12.0"
3+
const version = "v2.13.0"

internal/jet/column_assigment.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,20 @@
11
package jet
22

3-
// ColumnAssigment is interface wrapper around column assigment
3+
// ColumnAssigment is interface wrapper around column assignment
44
type ColumnAssigment interface {
55
Serializer
6-
isColumnAssigment()
6+
isColumnAssignment()
77
}
88

99
type columnAssigmentImpl struct {
10-
column ColumnSerializer
11-
expression Expression
10+
column ColumnSerializer
11+
toAssign Serializer
1212
}
1313

14-
func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
15-
return &columnAssigmentImpl{
16-
column: serializer,
17-
expression: expression,
18-
}
19-
}
20-
21-
func (a columnAssigmentImpl) isColumnAssigment() {}
14+
func (a columnAssigmentImpl) isColumnAssignment() {}
2215

2316
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
2417
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
2518
out.WriteString("=")
26-
a.expression.serialize(statement, out, FallTrough(options)...)
19+
a.toAssign.serialize(statement, out, FallTrough(options)...)
2720
}

internal/jet/column_list.go

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,60 @@
11
package jet
22

3+
import "fmt"
4+
35
// ColumnList is a helper type to support list of columns as single projection
46
type ColumnList []ColumnExpression
57

6-
// SET creates column assigment for each column in column list. expression should be created by ROW function
8+
func (cl ColumnList) isExpressionOrColumnList() {}
9+
10+
// SET creates a column assignment from the current ColumnList using the provided expression.
11+
// This assignment can be used in INSERT queries (e.g., to set columns on conflict) or in UPDATE queries
12+
// (e.g., to assign new values to columns).
13+
//
14+
// The expression can be:
15+
// - Another ColumnList: It must have the same length as the current ColumnList and each column must match by name
16+
// - A ROW expression containing values.
17+
// - A SELECT statement that returns a matching column list structure.
18+
//
19+
// Examples:
720
//
8-
// Link.UPDATE().
9-
// SET(Link.MutableColumns.SET(ROW(String("github.com"), Bool(false))).
10-
// WHERE(Link.ID.EQ(Int(0)))
11-
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
21+
// Link.AllColumns.SET(ROW(String("github.com"), Bool(false)))
22+
//
23+
// Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns)
24+
//
25+
// Link.MutableColumns.SET(
26+
// SELECT(Link.MutableColumns).
27+
// FROM(Link).
28+
// WHERE(Link.ID.EQ(Int(200))),
29+
// )
30+
func (cl ColumnList) SET(toAssignExp expressionOrColumnList) ColumnAssigment {
31+
32+
if toAssign, ok := toAssignExp.(ColumnList); ok {
33+
if len(cl) != len(toAssign) {
34+
panic(fmt.Sprintf("jet: column list length mismatch: expected %d columns, got %d", len(cl), len(toAssign)))
35+
}
36+
37+
var ret columnListAssigment
38+
39+
for i, column := range cl {
40+
if column.Name() != toAssign[i].Name() {
41+
panic(fmt.Sprintf("jet: column name mismatch at index %d: expected column '%s', got '%s'",
42+
i, column.Name(), toAssign[i].Name(),
43+
))
44+
}
45+
46+
ret = append(ret, columnAssigmentImpl{
47+
column: column,
48+
toAssign: toAssign[i],
49+
})
50+
}
51+
52+
return ret
53+
}
54+
1255
return columnAssigmentImpl{
13-
column: cl,
14-
expression: expression,
56+
column: cl,
57+
toAssign: toAssignExp,
1558
}
1659
}
1760

internal/jet/column_list_assigment.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package jet
2+
3+
type expressionOrColumnList interface {
4+
Serializer
5+
isExpressionOrColumnList()
6+
}
7+
8+
type columnListAssigment []ColumnAssigment
9+
10+
func (c columnListAssigment) isColumnAssignment() {}
11+
12+
func (c columnListAssigment) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
13+
for i, columnAssigment := range c {
14+
if i > 0 {
15+
out.WriteString(",")
16+
out.NewLine()
17+
}
18+
19+
columnAssigment.serialize(statement, out, options...)
20+
}
21+
}

internal/jet/column_list_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package jet
2+
3+
import (
4+
"github.com/stretchr/testify/require"
5+
"testing"
6+
)
7+
8+
func TestColumnList_SET(t *testing.T) {
9+
columnList1 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
10+
columnList2 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
11+
12+
columnList1.SET(columnList2)
13+
14+
columnList3 := ColumnList{IntegerColumn("id"), StringColumn("Name")}
15+
16+
require.PanicsWithValue(t, "jet: column list length mismatch: expected 2 columns, got 3", func() {
17+
columnList3.SET(columnList1)
18+
})
19+
20+
columnList4 := ColumnList{IntegerColumn("id"), StringColumn("FullName"), BoolColumn("active")}
21+
22+
require.PanicsWithValue(t, "jet: column name mismatch at index 1: expected column 'Name', got 'FullName'", func() {
23+
columnList1.SET(columnList4)
24+
})
25+
}

internal/jet/column_types.go

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
2828

2929
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
3030
return columnAssigmentImpl{
31-
column: i,
32-
expression: boolExp,
31+
column: i,
32+
toAssign: boolExp,
3333
}
3434
}
3535

@@ -72,8 +72,8 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
7272

7373
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
7474
return columnAssigmentImpl{
75-
column: i,
76-
expression: floatExp,
75+
column: i,
76+
toAssign: floatExp,
7777
}
7878
}
7979

@@ -117,8 +117,8 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
117117

118118
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
119119
return columnAssigmentImpl{
120-
column: i,
121-
expression: intExp,
120+
column: i,
121+
toAssign: intExp,
122122
}
123123
}
124124

@@ -163,8 +163,8 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
163163

164164
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
165165
return columnAssigmentImpl{
166-
column: i,
167-
expression: stringExp,
166+
column: i,
167+
toAssign: stringExp,
168168
}
169169
}
170170

@@ -208,8 +208,8 @@ func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob {
208208

209209
func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment {
210210
return columnAssigmentImpl{
211-
column: i,
212-
expression: blobExp,
211+
column: i,
212+
toAssign: blobExp,
213213
}
214214
}
215215

@@ -252,8 +252,8 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
252252

253253
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
254254
return columnAssigmentImpl{
255-
column: i,
256-
expression: timeExp,
255+
column: i,
256+
toAssign: timeExp,
257257
}
258258
}
259259

@@ -295,8 +295,8 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
295295

296296
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
297297
return columnAssigmentImpl{
298-
column: i,
299-
expression: timezExp,
298+
column: i,
299+
toAssign: timezExp,
300300
}
301301
}
302302

@@ -339,8 +339,8 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
339339

340340
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
341341
return columnAssigmentImpl{
342-
column: i,
343-
expression: timestampExp,
342+
column: i,
343+
toAssign: timestampExp,
344344
}
345345
}
346346

@@ -383,8 +383,8 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
383383

384384
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
385385
return columnAssigmentImpl{
386-
column: i,
387-
expression: timestampzExp,
386+
column: i,
387+
toAssign: timestampzExp,
388388
}
389389
}
390390

@@ -427,8 +427,8 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
427427

428428
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
429429
return columnAssigmentImpl{
430-
column: i,
431-
expression: dateExp,
430+
column: i,
431+
toAssign: dateExp,
432432
}
433433
}
434434

@@ -460,8 +460,8 @@ type intervalColumnImpl struct {
460460

461461
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
462462
return columnAssigmentImpl{
463-
column: i,
464-
expression: intervalExp,
463+
column: i,
464+
toAssign: intervalExp,
465465
}
466466
}
467467

@@ -516,8 +516,8 @@ func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] {
516516

517517
func (i *rangeColumnImpl[T]) SET(rangeExp Range[T]) ColumnAssigment {
518518
return columnAssigmentImpl{
519-
column: i,
520-
expression: rangeExp,
519+
column: i,
520+
toAssign: rangeExp,
521521
}
522522
}
523523

internal/jet/expression.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ type Expression interface {
99
Projection
1010
GroupByClause
1111
OrderByClause
12+
expressionOrColumnList
1213

1314
serializeForJsonValue(statement StatementType, out *SQLBuilder)
1415
setRoot(root Expression)
@@ -37,6 +38,8 @@ type ExpressionInterfaceImpl struct {
3738
Root Expression
3839
}
3940

41+
func (e *ExpressionInterfaceImpl) isExpressionOrColumnList() {}
42+
4043
func (e *ExpressionInterfaceImpl) setRoot(root Expression) {
4144
e.Root = root
4245
}

tests/postgres/insert_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,47 @@ RETURNING link.id AS "link.id",
196196
testutils.AssertExecAndRollback(t, stmt, db, 2)
197197
})
198198

199+
t.Run("do update column list", func(t *testing.T) {
200+
stmt := Link.INSERT().
201+
VALUES(1, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
202+
ON_CONFLICT(Link.ID).DO_UPDATE(
203+
SET(
204+
Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns),
205+
),
206+
).RETURNING(Link.AllColumns)
207+
208+
testutils.AssertDebugStatementSql(t, stmt, `
209+
INSERT INTO test_sample.link
210+
VALUES (1, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT)
211+
ON CONFLICT (id) DO UPDATE
212+
SET url = excluded.url,
213+
name = excluded.name,
214+
description = excluded.description
215+
RETURNING link.id AS "link.id",
216+
link.url AS "link.url",
217+
link.name AS "link.name",
218+
link.description AS "link.description";
219+
`)
220+
221+
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
222+
var dest []model.Link
223+
224+
err := stmt.QueryContext(ctx, tx, &dest)
225+
require.NoError(t, err)
226+
227+
testutils.AssertJSON(t, dest, `
228+
[
229+
{
230+
"ID": 1,
231+
"URL": "http://www.postgresqltutorial.com",
232+
"Name": "PostgreSQL Tutorial",
233+
"Description": null
234+
}
235+
]
236+
`)
237+
})
238+
})
239+
199240
t.Run("do update complex", func(t *testing.T) {
200241
skipForCockroachDB(t) // does not support ROW
201242

0 commit comments

Comments
 (0)