diff --git a/internal/endtoend/testdata/update_set/sqlite/go/query.sql.go b/internal/endtoend/testdata/update_set/sqlite/go/query.sql.go index 071e3e4a62..d13a3ceda2 100644 --- a/internal/endtoend/testdata/update_set/sqlite/go/query.sql.go +++ b/internal/endtoend/testdata/update_set/sqlite/go/query.sql.go @@ -22,3 +22,17 @@ func (q *Queries) UpdateSet(ctx context.Context, arg UpdateSetParams) error { _, err := q.db.ExecContext(ctx, updateSet, arg.Name, arg.Slug) return err } + +const updateSetQuoted = `-- name: UpdateSetQuoted :exec +UPDATE "foo" SET "name" = ? WHERE "slug" = ? +` + +type UpdateSetQuotedParams struct { + Name string + Slug string +} + +func (q *Queries) UpdateSetQuoted(ctx context.Context, arg UpdateSetQuotedParams) error { + _, err := q.db.ExecContext(ctx, updateSetQuoted, arg.Name, arg.Slug) + return err +} diff --git a/internal/endtoend/testdata/update_set/sqlite/query.sql b/internal/endtoend/testdata/update_set/sqlite/query.sql index 0f6603f503..3964a2ffe1 100644 --- a/internal/endtoend/testdata/update_set/sqlite/query.sql +++ b/internal/endtoend/testdata/update_set/sqlite/query.sql @@ -1,2 +1,5 @@ /* name: UpdateSet :exec */ UPDATE foo SET name = ? WHERE slug = ?; + +/* name: UpdateSetQuoted :exec */ +UPDATE "foo" SET "name" = ? WHERE "slug" = ?; diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index e86dd8ac82..14b211434c 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -1021,6 +1021,8 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast } type Update_stmt interface { + node + Qualified_table_name() parser.IQualified_table_nameContext GetStart() antlr.Token AllColumn_name() []parser.IColumn_nameContext @@ -1034,50 +1036,66 @@ func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node { return nil } - relations := &ast.List{} - tableName := n.Qualified_table_name().GetText() - rel := ast.RangeVar{ - Relname: &tableName, - Location: n.GetStart().GetStart(), - } - relations.Items = append(relations.Items, &rel) + if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok { + tableName := identifier(qualifiedName.Table_name().GetText()) + rel := ast.RangeVar{ + Relname: &tableName, + Location: n.GetStart().GetStart(), + } - list := &ast.List{} - for i, col := range n.AllColumn_name() { - colName := identifier(col.GetText()) - target := &ast.ResTarget{ - Name: &colName, - Val: c.convert(n.Expr(i)), + if qualifiedName.Schema_name() != nil { + schemaName := qualifiedName.Schema_name().GetText() + rel.Schemaname = &schemaName } - list.Items = append(list.Items, target) - } - var where ast.Node = nil - if n.WHERE_() != nil { - where = c.convert(n.Expr(len(n.AllExpr()) - 1)) - } + if qualifiedName.Alias() != nil { + alias := qualifiedName.Alias().GetText() + rel.Alias = &ast.Alias{Aliasname: &alias} + } - stmt := &ast.UpdateStmt{ - Relations: relations, - TargetList: list, - WhereClause: where, - FromClause: &ast.List{}, - WithClause: nil, // TODO: support with clause - } - if n, ok := n.(interface { - Returning_clause() parser.IReturning_clauseContext - }); ok { - stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause()) - } else { - stmt.ReturningList = c.convertReturning_caluseContext(nil) - } - if n, ok := n.(interface { - Limit_stmt() parser.ILimit_stmtContext - }); ok { - limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt()) - stmt.LimitCount = limitCount + relations := &ast.List{} + + relations.Items = append(relations.Items, &rel) + + list := &ast.List{} + for i, col := range n.AllColumn_name() { + colName := identifier(col.GetText()) + target := &ast.ResTarget{ + Name: &colName, + Val: c.convert(n.Expr(i)), + } + list.Items = append(list.Items, target) + } + + var where ast.Node = nil + if n.WHERE_() != nil { + where = c.convert(n.Expr(len(n.AllExpr()) - 1)) + } + + stmt := &ast.UpdateStmt{ + Relations: relations, + TargetList: list, + WhereClause: where, + FromClause: &ast.List{}, + WithClause: nil, // TODO: support with clause + } + if n, ok := n.(interface { + Returning_clause() parser.IReturning_clauseContext + }); ok { + stmt.ReturningList = c.convertReturning_caluseContext(n.Returning_clause()) + } else { + stmt.ReturningList = c.convertReturning_caluseContext(nil) + } + if n, ok := n.(interface { + Limit_stmt() parser.ILimit_stmtContext + }); ok { + limitCount, _ := c.convertLimit_stmtContext(n.Limit_stmt()) + stmt.LimitCount = limitCount + } + return stmt } - return stmt + + return todo("convertUpdate_stmtContext", n) } func (c *cc) convertBetweenExpr(n *parser.Expr_betweenContext) ast.Node {