Skip to content

Support OCT() function and fix CONV() mishandling of negative floats and empty string N #3020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 11, 2025
Merged
72 changes: 72 additions & 0 deletions enginetest/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8387,6 +8387,78 @@ SELECT * FROM cte WHERE d = 2;`,
Query: "SELECT CONV(i, 10, 2) FROM mytable",
Expected: []sql.Row{{"1"}, {"10"}, {"11"}},
},
{
Query: "SELECT OCT(8)",
Expected: []sql.Row{{"10"}},
},
{
Query: "SELECT OCT(255)",
Expected: []sql.Row{{"377"}},
},
{
Query: "SELECT OCT(0)",
Expected: []sql.Row{{"0"}},
},
{
Query: "SELECT OCT(1)",
Expected: []sql.Row{{"1"}},
},
{
Query: "SELECT OCT(NULL)",
Expected: []sql.Row{{nil}},
},
{
Query: "SELECT OCT(-1)",
Expected: []sql.Row{{"1777777777777777777777"}},
},
{
Query: "SELECT OCT(-8)",
Expected: []sql.Row{{"1777777777777777777770"}},
},
{
Query: "SELECT OCT(OCT(4))",
Expected: []sql.Row{{"4"}},
},
{
Query: "SELECT OCT('16')",
Expected: []sql.Row{{"20"}},
},
{
Query: "SELECT OCT('abc')",
Expected: []sql.Row{{"0"}},
},
{
Query: "SELECT OCT(15.7)",
Expected: []sql.Row{{"17"}},
},
{
Query: "SELECT OCT(-15.2)",
Expected: []sql.Row{{"1777777777777777777761"}},
},
{
Query: "SELECT OCT(HEX(SUBSTRING('127.0', 1, 3)))",
Expected: []sql.Row{{"1143625"}},
},
{
Query: "SELECT i, OCT(i), OCT(-i), OCT(i * 2) FROM mytable ORDER BY i",
Expected: []sql.Row{
{1, "1", "1777777777777777777777", "2"},
{2, "2", "1777777777777777777776", "4"},
{3, "3", "1777777777777777777775", "6"},
},
},
{
Query: "SELECT OCT(i) FROM mytable ORDER BY CONV(i, 10, 16)",
Expected: []sql.Row{{"1"}, {"2"}, {"3"}},
},
{
Query: "SELECT i FROM mytable WHERE OCT(s) > 0",
Expected: []sql.Row{},
},
{
Query: "SELECT s FROM mytable WHERE OCT(i*123) < 400",
Expected: []sql.Row{{"first row"}, {"second row"}},
},
{
Query: `SELECT t1.pk from one_pk join (one_pk t1 join one_pk t2 on t1.pk = t2.pk) on t1.pk = one_pk.pk and one_pk.pk = 1 join (one_pk t3 join one_pk t4 on t3.c1 is not null) on t3.pk = one_pk.pk and one_pk.c1 = 10`,
Expected: []sql.Row{{1}, {1}, {1}, {1}},
Expand Down
60 changes: 32 additions & 28 deletions sql/expression/function/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,62 +136,66 @@ func (c *Conv) WithChildren(children ...sql.Expression) (sql.Expression, error)
// This conversion truncates nVal as its first subpart that is convertable.
// nVal is treated as unsigned except nVal is negative.
func convertFromBase(ctx *sql.Context, nVal string, fromBase interface{}) interface{} {
fromBase, _, err := types.Int64.Convert(ctx, fromBase)
if err != nil {
if len(nVal) == 0 {
return nil
}

fromVal := int(math.Abs(float64(fromBase.(int64))))
// Convert and validate fromBase
baseVal, _, err := types.Int64.Convert(ctx, fromBase)
if err != nil {
return nil
}
fromVal := int(math.Abs(float64(baseVal.(int64))))
if fromVal < 2 || fromVal > 36 {
return nil
}

// Handle sign
negative := false
var upper string
var lower string
if nVal[0] == '-' {
switch nVal[0] {
case '-':
if len(nVal) == 1 {
return uint64(0)
}
negative = true
nVal = nVal[1:]
} else if nVal[0] == '+' {
case '+':
if len(nVal) == 1 {
return uint64(0)
}
nVal = nVal[1:]
}

// check for upper and lower bound for given fromBase
// Determine bounds based on sign
var maxLen int
if negative {
upper = strconv.FormatInt(math.MaxInt64, fromVal)
lower = strconv.FormatInt(math.MinInt64, fromVal)
if len(nVal) > len(lower) {
nVal = lower
} else if len(nVal) > len(upper) {
nVal = upper
maxLen = len(strconv.FormatInt(math.MinInt64, fromVal))
if len(nVal) > maxLen {
// Use MinInt64 representation in the given base
nVal = strconv.FormatInt(math.MinInt64, fromVal)[1:] // remove minus sign
}
} else {
upper = strconv.FormatUint(math.MaxUint64, fromVal)
lower = "0"
if len(nVal) < len(lower) {
nVal = lower
} else if len(nVal) > len(upper) {
nVal = upper
maxLen = len(strconv.FormatUint(math.MaxUint64, fromVal))
if len(nVal) > maxLen {
// Use MaxUint64 representation in the given base
nVal = strconv.FormatUint(math.MaxUint64, fromVal)
}
}

truncate := false
result := uint64(0)
i := 1
for !truncate && i <= len(nVal) {
// Find the longest valid prefix that can be converted
var result uint64
for i := 1; i <= len(nVal); i++ {
val, err := strconv.ParseUint(nVal[:i], fromVal, 64)
if err != nil {
truncate = true
return result
break
}
result = val
i++
}

if negative {
// MySQL returns signed value for negative inputs
return int64(result) * -1
}

return result
}

Expand Down
2 changes: 2 additions & 0 deletions sql/expression/function/conv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ func TestConv(t *testing.T) {
{"n is nil", types.Int32, sql.NewRow(nil, 16, 2), nil},
{"fromBase is nil", types.LongText, sql.NewRow('a', nil, 2), nil},
{"toBase is nil", types.LongText, sql.NewRow('a', 16, nil), nil},
{"empty n string", types.LongText, sql.NewRow("", 3, 4), nil},
{"empty arg strings", types.LongText, sql.NewRow(4, "", ""), nil},

// invalid inputs
{"invalid N", types.LongText, sql.NewRow("r", 16, 2), "0"},
Expand Down
91 changes: 91 additions & 0 deletions sql/expression/function/oct.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

// Oct function provides a string representation for the octal value of N, where N is a decimal (base 10) number.
type Oct struct {
n sql.Expression
}

var _ sql.FunctionExpression = (*Oct)(nil)
var _ sql.CollationCoercible = (*Oct)(nil)

// NewOct returns a new Oct expression.
func NewOct(n sql.Expression) sql.Expression { return &Oct{n} }

// FunctionName implements sql.FunctionExpression.
func (o *Oct) FunctionName() string {
return "oct"
}

// Description implements sql.FunctionExpression.
func (o *Oct) Description() string {
return "returns a string representation for octal value of N, where N is a decimal (base 10) number."
}

// Type implements the Expression interface.
func (o *Oct) Type() sql.Type {
return types.LongText
}

// IsNullable implements the Expression interface.
func (o *Oct) IsNullable() bool {
return o.n.IsNullable()
}

// Eval implements the Expression interface.
func (o *Oct) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
// Convert a decimal (base 10) number to octal (base 8)
return NewConv(
o.n,
expression.NewLiteral(10, types.Int64),
expression.NewLiteral(8, types.Int64),
).Eval(ctx, row)
}

// Resolved implements the Expression interface.
func (o *Oct) Resolved() bool {
return o.n.Resolved()
}

// Children implements the Expression interface.
func (o *Oct) Children() []sql.Expression {
return []sql.Expression{o.n}
}

// WithChildren implements the Expression interface.
func (o *Oct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(o, len(children), 1)
}
return NewOct(children[0]), nil
}

func (o *Oct) String() string {
return fmt.Sprintf("%s(%s)", o.FunctionName(), o.n)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*Oct) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return ctx.GetCollation(), 4 // strings with collations
}
80 changes: 80 additions & 0 deletions sql/expression/function/oct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
"math"
"testing"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

type test struct {
name string
nType sql.Type
row sql.Row
expected interface{}
}

func TestOct(t *testing.T) {
tests := []test{
// NULL input
{"n is nil", types.Int32, sql.NewRow(nil), nil},

// Positive numbers
{"positive small", types.Int32, sql.NewRow(8), "10"},
{"positive medium", types.Int32, sql.NewRow(64), "100"},
{"positive large", types.Int32, sql.NewRow(4095), "7777"},
{"positive huge", types.Int64, sql.NewRow(123456789), "726746425"},

// Negative numbers
{"negative small", types.Int32, sql.NewRow(-8), "1777777777777777777770"},
{"negative medium", types.Int32, sql.NewRow(-64), "1777777777777777777700"},
{"negative large", types.Int32, sql.NewRow(-4095), "1777777777777777770001"},

// Zero
{"zero", types.Int32, sql.NewRow(0), "0"},

// String inputs
{"string number", types.LongText, sql.NewRow("15"), "17"},
{"alpha string", types.LongText, sql.NewRow("abc"), "0"},
{"mixed string", types.LongText, sql.NewRow("123abc"), "173"},

// Edge cases
{"max int32", types.Int32, sql.NewRow(math.MaxInt32), "17777777777"},
{"min int32", types.Int32, sql.NewRow(math.MinInt32), "1777777777760000000000"},
{"max int64", types.Int64, sql.NewRow(math.MaxInt64), "777777777777777777777"},
{"min int64", types.Int64, sql.NewRow(math.MinInt64), "1000000000000000000000"},

// Decimal numbers
{"decimal", types.Float64, sql.NewRow(15.5), "17"},
{"negative decimal", types.Float64, sql.NewRow(-15.5), "1777777777777777777761"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := NewOct(expression.NewGetField(0, tt.nType, "n", true))
result, err := f.Eval(sql.NewEmptyContext(), tt.row)
if err != nil {
t.Fatal(err)
}
if result != tt.expected {
t.Errorf("got %v; expected %v", result, tt.expected)
}
})
}
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ var BuiltIns = []sql.Function{
sql.Function1{Name: "ntile", Fn: window.NewNTile},
sql.FunctionN{Name: "now", Fn: NewNow},
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function1{Name: "oct", Fn: NewOct},
sql.Function1{Name: "octet_length", Fn: NewLength},
sql.Function1{Name: "ord", Fn: NewOrd},
sql.Function0{Name: "pi", Fn: NewPi},
Expand Down