Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,35 +170,35 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (mapping.Stat
var state mapping.State
switch t {
case TYPE_TIMESTAMP:
v, err := getMappedTimestamp(t, val.Value)
v, err := inferTimestamp(t, val.Value)
if err != nil {
return mapping.StateError, err
}
state = mapping.BindTimestamp(*s.preparedStmt, mapping.IdxT(n+1), v)
case TYPE_TIMESTAMP_TZ:
v, err := getMappedTimestamp(t, val.Value)
v, err := inferTimestamp(t, val.Value)
if err != nil {
return mapping.StateError, err
}
state = mapping.BindTimestampTZ(*s.preparedStmt, mapping.IdxT(n+1), v)
case TYPE_TIMESTAMP_S:
v, err := getMappedTimestampS(val.Value)
v, err := inferTimestampS(val.Value)
if err != nil {
return mapping.StateError, err
}
tS := mapping.CreateTimestampS(v)
state = mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), tS)
mapping.DestroyValue(&tS)
case TYPE_TIMESTAMP_MS:
v, err := getMappedTimestampMS(val.Value)
v, err := inferTimestampMS(val.Value)
if err != nil {
return mapping.StateError, err
}
tMS := mapping.CreateTimestampMS(v)
state = mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), tMS)
mapping.DestroyValue(&tMS)
case TYPE_TIMESTAMP_NS:
v, err := getMappedTimestampNS(val.Value)
v, err := inferTimestampNS(val.Value)
if err != nil {
return mapping.StateError, err
}
Expand All @@ -210,7 +210,7 @@ func (s *Stmt) bindTimestamp(val driver.NamedValue, t Type, n int) (mapping.Stat
}

func (s *Stmt) bindDate(val driver.NamedValue, n int) (mapping.State, error) {
date, err := getMappedDate(val.Value)
date, err := inferDate(val.Value)
if err != nil {
return mapping.StateError, err
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func (s *Stmt) bindCompositeValue(val driver.NamedValue, n int) (mapping.State,
}

func (s *Stmt) tryBindComplexValue(val driver.NamedValue, n int) (mapping.State, error) {
lt, mappedVal, err := createValueByReflection(val.Value)
lt, mappedVal, err := inferLogicalTypeAndValue(val.Value)
defer mapping.DestroyLogicalType(&lt)
defer mapping.DestroyValue(&mappedVal)
if err != nil {
Expand Down Expand Up @@ -393,7 +393,11 @@ func (s *Stmt) bindValue(val driver.NamedValue, n int) (mapping.State, error) {
case []byte:
return mapping.BindBlob(*s.preparedStmt, mapping.IdxT(n+1), v), nil
case Interval:
return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), v.getMappedInterval()), nil
i, inferErr := inferInterval(v)
if inferErr != nil {
return mapping.StateError, inferErr
}
return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), i), nil
case nil:
return mapping.BindNull(*s.preparedStmt, mapping.IdxT(n+1)), nil
}
Expand Down
2 changes: 1 addition & 1 deletion statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,5 +842,5 @@ func TestDriverValuer(t *testing.T) {
// Expected to fail - no driver.Valuer implementation
_, err = db.Exec(`INSERT INTO valuer_test (ids) VALUES (?)`, []uuid.UUID{uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"), uuid.MustParse("3a92e387-4b7d-4098-b273-967d48f6925f")})
require.Error(t, err, "[]uuid.UUID should fail without driver.Valuer")
require.Contains(t, err.Error(), "unsupported data type: UUID")
require.Contains(t, err.Error(), castErrMsg)
}
1 change: 1 addition & 0 deletions type.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const (
TYPE_ANY = mapping.TypeAny
TYPE_BIGNUM = mapping.TypeBigNum
TYPE_SQLNULL = mapping.TypeSQLNull
// TODO: add TYPE_TIME_NS here, or support it.
)

// FIXME: Implement support for these types.
Expand Down
125 changes: 117 additions & 8 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import (
"github.com/marcboeker/go-duckdb/mapping"
)

// Precomputed reflect type values to avoid repeated allocations.
// go-duckdb exports the following type wrappers:
// UUID, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT).

// Pre-computed reflect type values to avoid repeated allocations.
var (
reflectTypeBool = reflect.TypeOf(true)
reflectTypeInt8 = reflect.TypeOf(int8(0))
Expand Down Expand Up @@ -92,6 +95,27 @@ func (u *UUID) Value() (driver.Value, error) {
return u.String(), nil
}

func inferUUID(val any) (mapping.HugeInt, error) {
var id UUID
switch v := val.(type) {
case UUID:
id = v
case *UUID:
id = *v
case []uint8:
if len(v) != uuidLength {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String())
}
for i := range uuidLength {
id[i] = v[i]
}
default:
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String())
}
hi := uuidToHugeInt(id)
return hi, nil
}

// duckdb_hugeint is composed of (lower, upper) components.
// The value is computed as: upper * 2^64 + lower

Expand Down Expand Up @@ -134,6 +158,67 @@ func hugeIntFromNative(i *big.Int) (mapping.HugeInt, error) {
return mapping.NewHugeInt(r.Uint64(), q.Int64()), nil
}

func inferHugeInt(val any) (mapping.HugeInt, error) {
var err error
var hi mapping.HugeInt
switch v := val.(type) {
case uint8:
hi = mapping.NewHugeInt(uint64(v), 0)
case int8:
hi = mapping.NewHugeInt(uint64(v), 0)
case uint16:
hi = mapping.NewHugeInt(uint64(v), 0)
case int16:
hi = mapping.NewHugeInt(uint64(v), 0)
case uint32:
hi = mapping.NewHugeInt(uint64(v), 0)
case int32:
hi = mapping.NewHugeInt(uint64(v), 0)
case uint64:
hi = mapping.NewHugeInt(v, 0)
case int64:
hi, err = hugeIntFromNative(big.NewInt(v))
if err != nil {
return mapping.HugeInt{}, err
}
case uint:
hi = mapping.NewHugeInt(uint64(v), 0)
case int:
hi, err = hugeIntFromNative(big.NewInt(int64(v)))
if err != nil {
return mapping.HugeInt{}, err
}
case float32:
hi, err = hugeIntFromNative(big.NewInt(int64(v)))
if err != nil {
return mapping.HugeInt{}, err
}
case float64:
hi, err = hugeIntFromNative(big.NewInt(int64(v)))
if err != nil {
return mapping.HugeInt{}, err
}
case *big.Int:
if v == nil {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
}
if hi, err = hugeIntFromNative(v); err != nil {
return mapping.HugeInt{}, err
}
case Decimal:
if v.Value == nil {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
}
if hi, err = hugeIntFromNative(v.Value); err != nil {
return mapping.HugeInt{}, err
}
default:
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
}

return hi, nil
}

type Map map[any]any

func (m *Map) Scan(v any) error {
Expand All @@ -160,8 +245,15 @@ type Interval struct {
Micros int64 `json:"micros"`
}

func (i *Interval) getMappedInterval() mapping.Interval {
return mapping.NewInterval(i.Months, i.Days, i.Micros)
func inferInterval(val any) (mapping.Interval, error) {
var i Interval
switch v := val.(type) {
case Interval:
i = v
default:
return mapping.Interval{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(i).String())
}
return mapping.NewInterval(i.Months, i.Days, i.Micros), nil
}

// Composite can be used as the `Scanner` type for any composite types (maps, lists, structs).
Expand Down Expand Up @@ -268,27 +360,27 @@ func getTSTicks(t Type, val any) (int64, error) {
return ti.UnixNano(), nil
}

func getMappedTimestamp(t Type, val any) (mapping.Timestamp, error) {
func inferTimestamp(t Type, val any) (mapping.Timestamp, error) {
ticks, err := getTSTicks(t, val)
return mapping.NewTimestamp(ticks), err
}

func getMappedTimestampS(val any) (mapping.TimestampS, error) {
func inferTimestampS(val any) (mapping.TimestampS, error) {
ticks, err := getTSTicks(TYPE_TIMESTAMP_S, val)
return mapping.NewTimestampS(ticks), err
}

func getMappedTimestampMS(val any) (mapping.TimestampMS, error) {
func inferTimestampMS(val any) (mapping.TimestampMS, error) {
ticks, err := getTSTicks(TYPE_TIMESTAMP_MS, val)
return mapping.NewTimestampMS(ticks), err
}

func getMappedTimestampNS(val any) (mapping.TimestampNS, error) {
func inferTimestampNS(val any) (mapping.TimestampNS, error) {
ticks, err := getTSTicks(TYPE_TIMESTAMP_NS, val)
return mapping.NewTimestampNS(ticks), err
}

func getMappedDate[T any](val T) (mapping.Date, error) {
func inferDate[T any](val T) (mapping.Date, error) {
ti, err := castToTime(val)
if err != nil {
return mapping.Date{}, err
Expand All @@ -298,6 +390,23 @@ func getMappedDate[T any](val T) (mapping.Date, error) {
return date, err
}

func inferTime(val any) (mapping.Time, error) {
ticks, err := getTimeTicks(val)
if err != nil {
return mapping.Time{}, err
}
return mapping.NewTime(ticks), nil
}

func inferTimeTZ(val any) (mapping.TimeTZ, error) {
ticks, err := getTimeTicks(val)
if err != nil {
return mapping.TimeTZ{}, err
}
// The UTC offset is 0.
return mapping.CreateTimeTZ(ticks, 0), nil
}

func getTimeTicks[T any](val T) (int64, error) {
ti, err := castToTime(val)
if err != nil {
Expand Down
47 changes: 47 additions & 0 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,3 +1098,50 @@ func TestUnionTypes(t *testing.T) {
require.Equal(t, int32(123), val.Value)
})
}

func TestInferPrimitiveType(t *testing.T) {
db := openDbWrapper(t, ``)
defer closeDbWrapper(t, db)

testCases := []struct {
input any
}{
{[]Map{nil}},
{[]bool{true, false}},
{[]int8{-7}},
{[]int16{-42}},
{[]int32{-4}},
{[]int64{-6}},
{[]int{-22}},
{[]uint8{7}},
{[]uint16{42}},
{[]uint32{4}},
{[]uint64{6}},
{[]uint{22}},
{[]float32{7.8}},
{[]float64{22.3}},
{[]string{"Hello from Amsterdam!"}},
{[][]byte{{71, 111}}},
{[]time.Time{time.Now()}},
{[]Interval{{22, 10, 7}}},
{[]*big.Int{big.NewInt(22)}},
{[]Decimal{{2, 2, big.NewInt(7)}}},
{[]UUID{UUID(uuid.New())}},
}
for _, tc := range testCases {
_, err := db.Exec(`SELECT a FROM (VALUES (?)) t(a)`, tc.input)
require.NoError(t, err)
}

// Not yet supported.
testCases = []struct {
input any
}{
{[]Union{{42, "n"}}},
{[]Map{map[any]any{"hello": "world", "beautiful": "day"}}},
}
for _, tc := range testCases {
_, err := db.Exec(`SELECT a FROM (VALUES (?)) t(a)`, tc.input)
require.ErrorContains(t, err, unsupportedTypeErrMsg)
}
}
Loading
Loading