Skip to content

Support revert old column attributes #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ CREATE TABLE user (
//CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);

println(sql1.StringDown())
//ALTER TABLE `user` MODIFY COLUMN `id` int(11);
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime;
//DROP INDEX `idx_name_age` ON `user`;
}
```
10 changes: 6 additions & 4 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,13 @@ CREATE TABLE user (

sql1.Diff(*sql2)
println(sql1.StringUp())
// ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY;
// ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP();
// CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);
//ALTER TABLE `user` MODIFY COLUMN `id` int(11) AUTO_INCREMENT PRIMARY KEY;
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP();
//CREATE UNIQUE INDEX `idx_name_age` ON `user`(`name`, `age`);

println(sql1.StringDown())
// DROP INDEX `idx_name_age` ON `user`;
//ALTER TABLE `user` MODIFY COLUMN `id` int(11);
//ALTER TABLE `user` MODIFY COLUMN `updated_at` datetime;
//DROP INDEX `idx_name_age` ON `user`;
}
```
6 changes: 3 additions & 3 deletions avro/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,19 @@ func getAvroType(col element.Column) interface{} {
"type": "string",
"connect.version": 1,
"connect.parameters": map[string]string{
"allowed": strings.Join(col.MysqlType.Elems, ","),
"allowed": strings.Join(col.CurrentAttr.MysqlType.Elems, ","),
},
"connect.default": "init",
"connect.name": "io.debezium.data.Enum",
}
}

switch col.MysqlType.EvalType() {
switch col.CurrentAttr.MysqlType.EvalType() {
case types.ETInt:
return "int"

case types.ETDecimal:
displayFlen, displayDecimal := col.MysqlType.Flen, col.MysqlType.Decimal
displayFlen, displayDecimal := col.CurrentAttr.MysqlType.Flen, col.CurrentAttr.MysqlType.Decimal
return map[string]interface{}{
"type": "bytes",
"scale": displayDecimal,
Expand Down
88 changes: 65 additions & 23 deletions element/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,34 @@ const (
LowerRestoreFlag = format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameLowercase | format.RestoreNameBackQuotes
)

// Column ...
type Column struct {
Node
type SqlAttr struct {
MysqlType *types.FieldType
PgType *ptypes.T
LiteType *sqlite.Type
Options []*ast.ColumnOption
Comment string
}

// Column ...
type Column struct {
Node

CurrentAttr SqlAttr
PreviousAttr SqlAttr
}

// GetType ...
func (c Column) GetType() byte {
if c.MysqlType != nil {
return c.MysqlType.Tp
if c.CurrentAttr.MysqlType != nil {
return c.CurrentAttr.MysqlType.Tp
}

return 0
}

// HasDefaultValue ...
func (c Column) HasDefaultValue() bool {
for _, opt := range c.Options {
for _, opt := range c.CurrentAttr.Options {
if opt.Tp == ast.ColumnOptionDefaultValue {
return true
}
Expand All @@ -54,7 +60,7 @@ func (c Column) HasDefaultValue() bool {

func (c Column) hashValue() string {
strHash := sql.EscapeSqlName(c.Name)
strHash += c.typeDefinition()
strHash += c.typeDefinition(false)
hash := md5.Sum([]byte(strHash))
return hex.EncodeToString(hash[:])
}
Expand All @@ -71,7 +77,7 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
strSql += strings.Repeat(" ", ident-len(c.Name))
}

strSql += c.definition()
strSql += c.definition(false)

if ident < 0 {
if after != "" {
Expand All @@ -90,10 +96,27 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
return []string{fmt.Sprintf(sql.AlterTableDropColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name))}

case MigrateModifyAction:
def := strings.Replace(c.definition(), sql.PrimaryOption(), "", 1)
def, isPk := c.pkDefinition(false)
if isPk {
if _, isPrevPk := c.pkDefinition(true); isPrevPk {
// avoid repeat define primary key
def = strings.Replace(def, " "+sql.PrimaryOption(), "", 1)
}
}

return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+def)}

case MigrateRevertAction:
prevDef, isPrevPk := c.pkDefinition(true)
if isPrevPk {
if _, isPk := c.pkDefinition(false); isPk {
// avoid repeat define primary key
prevDef = strings.Replace(prevDef, " "+sql.PrimaryOption(), "", 1)
}
}

return []string{fmt.Sprintf(sql.AlterTableModifyColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.Name)+prevDef)}

case MigrateRenameAction:
return []string{fmt.Sprintf(sql.AlterTableRenameColumnStm(), sql.EscapeSqlName(tbName), sql.EscapeSqlName(c.OldName), sql.EscapeSqlName(c.Name))}

Expand All @@ -103,12 +126,12 @@ func (c Column) migrationUp(tbName, after string, ident int) []string {
}

func (c Column) migrationCommentUp(tbName string) []string {
if c.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect {
if c.CurrentAttr.Comment == "" || sql.GetDialect() != sql_templates.PostgresDialect {
return nil
}

// apply for postgres only
return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.Comment)}
return []string{fmt.Sprintf(sql.ColumnComment(), tbName, c.Name, c.CurrentAttr.Comment)}
}

func (c Column) migrationDown(tbName, after string) []string {
Expand All @@ -123,7 +146,7 @@ func (c Column) migrationDown(tbName, after string) []string {
c.Action = MigrateAddAction

case MigrateModifyAction:
return nil
c.Action = MigrateRevertAction

case MigrateRenameAction:
c.Name, c.OldName = c.OldName, c.Name
Expand All @@ -135,10 +158,19 @@ func (c Column) migrationDown(tbName, after string) []string {
return c.migrationUp(tbName, after, -1)
}

func (c Column) definition() string {
strSql := c.typeDefinition()
func (c Column) pkDefinition(isPrev bool) (string, bool) {
attr := c.CurrentAttr
if isPrev {
attr = c.PreviousAttr
}
strSql := c.typeDefinition(isPrev)

isPrimaryKey := false
for _, opt := range attr.Options {
if opt.Tp == ast.ColumnOptionPrimaryKey {
isPrimaryKey = true
}

for _, opt := range c.Options {
b := bytes.NewBufferString("")
var ctx *format.RestoreCtx

Expand All @@ -157,17 +189,27 @@ func (c Column) definition() string {
strSql += " " + b.String()
}

return strSql
return strSql, isPrimaryKey
}

func (c Column) definition(isPrev bool) string {
def, _ := c.pkDefinition(isPrev)
return def
}

func (c Column) typeDefinition() string {
func (c Column) typeDefinition(isPrev bool) string {
attr := c.CurrentAttr
if isPrev {
attr = c.PreviousAttr
}

switch {
case sql.IsPostgres() && c.PgType != nil:
return " " + c.PgType.SQLString()
case sql.IsSqlite() && c.LiteType != nil:
return " " + c.LiteType.Name.Name
case c.MysqlType != nil:
return " " + c.MysqlType.String()
case sql.IsPostgres() && attr.PgType != nil:
return " " + attr.PgType.SQLString()
case sql.IsSqlite() && attr.LiteType != nil:
return " " + attr.LiteType.Name.Name
case attr.MysqlType != nil:
return " " + attr.MysqlType.String()
}

return "" // column type is empty
Expand Down
2 changes: 1 addition & 1 deletion element/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (m *Migration) AddComment(tbName, colName, comment string) {
return
}

m.Tables[id].Columns[colIdx].Comment = comment
m.Tables[id].Columns[colIdx].CurrentAttr.Comment = comment
}

// AddIndex ...
Expand Down
2 changes: 2 additions & 0 deletions element/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ const (
MigrateRemoveAction
// MigrateModifyAction ...
MigrateModifyAction
// MigrateRevertAction ...
MigrateRevertAction
// MigrateRenameAction ...
MigrateRenameAction
)
Expand Down
19 changes: 10 additions & 9 deletions element/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,18 @@ func (t *Table) AddColumn(col Column) {
t.Columns[id] = col

default:
t.Columns[id].Options = append(t.Columns[id].Options, col.Options...)
t.Columns[id].CurrentAttr.Options = append(t.Columns[id].CurrentAttr.Options, col.CurrentAttr.Options...)

if size := len(t.Columns[id].Options); size > 0 {
for i := range t.Columns[id].Options[:size-1] {
if t.Columns[id].Options[i].Tp == ast.ColumnOptionPrimaryKey {
t.Columns[id].Options[i], t.Columns[id].Options[size-1] = t.Columns[id].Options[size-1], t.Columns[id].Options[i]
if size := len(t.Columns[id].CurrentAttr.Options); size > 0 {
for i := range t.Columns[id].CurrentAttr.Options[:size-1] {
if t.Columns[id].CurrentAttr.Options[i].Tp == ast.ColumnOptionPrimaryKey {
t.Columns[id].CurrentAttr.Options[i], t.Columns[id].CurrentAttr.Options[size-1] = t.Columns[id].CurrentAttr.Options[size-1], t.Columns[id].CurrentAttr.Options[i]
break
}
}
}

t.Columns[id].MysqlType = col.MysqlType
t.Columns[id].CurrentAttr.MysqlType = col.CurrentAttr.MysqlType
return
}

Expand Down Expand Up @@ -291,10 +291,11 @@ func (t *Table) Diff(old Table) {
for i := range t.Columns {
if j := old.getIndexColumn(t.Columns[i].Name); t.Columns[i].Action == MigrateAddAction &&
j >= 0 && old.Columns[j].Action != MigrateNoAction {
if hasChangedMysqlOptions(t.Columns[i].Options, old.Columns[j].Options) ||
hasChangedMysqlType(t.Columns[i].MysqlType, old.Columns[j].MysqlType) ||
hasChangePostgresType(t.Columns[i].PgType, old.Columns[j].PgType) {
if hasChangedMysqlOptions(t.Columns[i].CurrentAttr.Options, old.Columns[j].CurrentAttr.Options) ||
hasChangedMysqlType(t.Columns[i].CurrentAttr.MysqlType, old.Columns[j].CurrentAttr.MysqlType) ||
hasChangePostgresType(t.Columns[i].CurrentAttr.PgType, old.Columns[j].CurrentAttr.PgType) {
t.Columns[i].Action = MigrateModifyAction
t.Columns[i].PreviousAttr = old.Columns[j].CurrentAttr
} else {
t.Columns[i].Action = MigrateNoAction
}
Expand Down
42 changes: 25 additions & 17 deletions sql-parser/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
})
} else {
p.Migration.AddColumn(alter.Table.Text(), element.Column{
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
},
},
},
})
Expand Down Expand Up @@ -113,9 +115,11 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
if len(alter.Specs[i].NewColumns) > 0 {
for j := range alter.Specs[i].NewColumns {
col := element.Column{
Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction},
MysqlType: alter.Specs[i].NewColumns[j].Tp,
Comment: alter.Specs[i].Comment,
Node: element.Node{Name: alter.Specs[i].NewColumns[j].Name.Name.O, Action: element.MigrateModifyAction},
CurrentAttr: element.SqlAttr{
MysqlType: alter.Specs[i].NewColumns[j].Tp,
Comment: alter.Specs[i].Comment,
},
}
p.Migration.AddColumn(alter.Table.Name.O, col)
}
Expand Down Expand Up @@ -161,11 +165,13 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
})
} else {
tb.AddColumn(element.Column{
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
Node: element.Node{Name: cols[0], Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: nil,
Options: []*ast.ColumnOption{
{
Tp: ast.ColumnOptionPrimaryKey,
},
},
},
})
Expand Down Expand Up @@ -218,10 +224,12 @@ func (p *Parser) Enter(in ast.Node) (ast.Node, bool) {
}

column := element.Column{
Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction},
MysqlType: def.Tp,
Options: def.Options,
Comment: comment,
Node: element.Node{Name: def.Name.Name.O, Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
MysqlType: def.Tp,
Options: def.Options,
Comment: comment,
},
}
p.Migration.AddColumn("", column)
}
Expand Down
26 changes: 16 additions & 10 deletions sql-parser/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,24 @@ func (p *Parser) walker(ctx interface{}, node interface{}) (stop bool) {

case *tree.AlterTableAlterColumnType:
col := element.Column{
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
PgType: nc.ToType,
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
CurrentAttr: element.SqlAttr{
PgType: nc.ToType,
},
}
p.Migration.AddColumn(n.Table.String(), col)

case *tree.AlterTableSetDefault:
if nc.Default != nil {
col := element.Column{
Node: element.Node{Name: nc.Column.String(), Action: element.MigrateModifyAction},
Options: []*ast.ColumnOption{{
Expr: nil,
Tp: ast.ColumnOptionDefaultValue,
StrValue: nc.Default.String(),
}},
CurrentAttr: element.SqlAttr{
Options: []*ast.ColumnOption{{
Expr: nil,
Tp: ast.ColumnOptionDefaultValue,
StrValue: nc.Default.String(),
}},
},
}
p.Migration.AddColumn(n.Table.String(), col)
}
Expand Down Expand Up @@ -166,9 +170,11 @@ func postgresColumn(n *tree.ColumnTableDef) (element.Column, []element.Index) {
}

return element.Column{
Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction},
PgType: n.Type,
Options: opts,
Node: element.Node{Name: n.Name.String(), Action: element.MigrateAddAction},
CurrentAttr: element.SqlAttr{
PgType: n.Type,
Options: opts,
},
}, indexes
}

Expand Down
Loading
Loading