Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,30 @@ import (

// Pre-computed reflect type values to avoid repeated allocations.
var (
reflectTypeBool = reflect.TypeOf(true)
reflectTypeInt8 = reflect.TypeOf(int8(0))
reflectTypeInt16 = reflect.TypeOf(int16(0))
reflectTypeInt32 = reflect.TypeOf(int32(0))
reflectTypeInt64 = reflect.TypeOf(int64(0))
reflectTypeUint8 = reflect.TypeOf(uint8(0))
reflectTypeUint16 = reflect.TypeOf(uint16(0))
reflectTypeUint32 = reflect.TypeOf(uint32(0))
reflectTypeUint64 = reflect.TypeOf(uint64(0))
reflectTypeFloat32 = reflect.TypeOf(float32(0))
reflectTypeFloat64 = reflect.TypeOf(float64(0))
reflectTypeTime = reflect.TypeOf(time.Time{})
reflectTypeInterval = reflect.TypeOf(Interval{})
reflectTypeBigInt = reflect.TypeOf(big.NewInt(0))
reflectTypeString = reflect.TypeOf("")
reflectTypeBytes = reflect.TypeOf([]byte{})
reflectTypeDecimal = reflect.TypeOf(Decimal{})
reflectTypeSliceAny = reflect.TypeOf([]any{})
reflectTypeMapString = reflect.TypeOf(map[string]any{})
reflectTypeMap = reflect.TypeOf(Map{})
reflectTypeUnion = reflect.TypeOf(Union{})
reflectTypeBool = reflect.TypeFor[bool]()
reflectTypeInt8 = reflect.TypeFor[int8]()
reflectTypeInt16 = reflect.TypeFor[int16]()
reflectTypeInt32 = reflect.TypeFor[int32]()
reflectTypeInt64 = reflect.TypeFor[int64]()
reflectTypeUint8 = reflect.TypeFor[uint8]()
reflectTypeUint16 = reflect.TypeFor[uint16]()
reflectTypeUint32 = reflect.TypeFor[uint32]()
reflectTypeUint64 = reflect.TypeFor[uint64]()
reflectTypeFloat32 = reflect.TypeFor[float32]()
reflectTypeFloat64 = reflect.TypeFor[float64]()
reflectTypeTime = reflect.TypeFor[time.Time]()
reflectTypeInterval = reflect.TypeFor[Interval]()
reflectTypeBigInt = reflect.TypeFor[*big.Int]()
reflectTypeString = reflect.TypeFor[string]()
reflectTypeBytes = reflect.TypeFor[[]byte]()
reflectTypeDecimal = reflect.TypeFor[Decimal]()
reflectTypeSliceAny = reflect.TypeFor[[]any]()
reflectTypeMapString = reflect.TypeFor[map[string]any]()
reflectTypeMap = reflect.TypeFor[Map]()
reflectTypeUnion = reflect.TypeFor[Union]()
reflectTypeAny = reflect.TypeFor[any]()
reflectTypeUUID = reflect.TypeFor[UUID]()
reflectTypeHugeInt = reflect.TypeFor[mapping.HugeInt]()
)

type numericType interface {
Expand Down Expand Up @@ -104,13 +106,13 @@ func inferUUID(val any) (mapping.HugeInt, error) {
id = *v
case []uint8:
if len(v) != uuidLength {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String())
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflectTypeUUID.String())
}
for i := range uuidLength {
id[i] = v[i]
}
default:
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(id).String())
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflectTypeUUID.String())
}
hi := uuidToHugeInt(id)
return hi, nil
Expand Down Expand Up @@ -200,20 +202,20 @@ func inferHugeInt(val any) (mapping.HugeInt, error) {
}
case *big.Int:
if v == nil {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflectTypeHugeInt.String())
}
if hi, err = hugeIntFromNative(v); err != nil {
return mapping.HugeInt{}, err
}
case Decimal:
if v.Value == nil {
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflectTypeHugeInt.String())
}
if hi, err = hugeIntFromNative(v.Value); err != nil {
return mapping.HugeInt{}, err
}
default:
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(hi).String())
return mapping.HugeInt{}, castError(reflect.TypeOf(val).String(), reflectTypeHugeInt.String())
}

return hi, nil
Expand Down Expand Up @@ -251,7 +253,7 @@ func inferInterval(val any) (mapping.Interval, error) {
case Interval:
i = v
default:
return mapping.Interval{}, castError(reflect.TypeOf(val).String(), reflect.TypeOf(i).String())
return mapping.Interval{}, castError(reflect.TypeOf(val).String(), reflectTypeInterval.String())
}
return mapping.NewInterval(i.Months, i.Days, i.Micros), nil
}
Expand Down Expand Up @@ -327,7 +329,7 @@ func castToTime(val any) (time.Time, error) {
case time.Time:
ti = v
default:
return ti, castError(reflect.TypeOf(val).String(), reflect.TypeOf(ti).String())
return ti, castError(reflect.TypeOf(val).String(), reflectTypeTime.String())
}
return ti.UTC(), nil
}
Expand Down
5 changes: 2 additions & 3 deletions types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"database/sql"
"fmt"
"math/big"
"reflect"
"strconv"
"testing"
"time"
Expand Down Expand Up @@ -1010,8 +1009,8 @@ func TestJSONColType(t *testing.T) {
require.Len(t, columnTypes, 2)
require.Equal(t, aliasJSON, columnTypes[0].DatabaseTypeName())
require.Equal(t, typeToStringMap[TYPE_BIGINT], columnTypes[1].DatabaseTypeName())
require.Equal(t, reflect.TypeOf((*any)(nil)).Elem(), columnTypes[0].ScanType())
require.Equal(t, reflect.TypeOf(int64(0)), columnTypes[1].ScanType())
require.Equal(t, reflectTypeAny, columnTypes[0].ScanType())
require.Equal(t, reflectTypeInt64, columnTypes[1].ScanType())
}

func TestUnionTypes(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions value.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ func createSliceValue[T any](lt mapping.LogicalType, t Type, val T) (mapping.Val
func createStructValue(lt mapping.LogicalType, val any) (mapping.Value, error) {
m, ok := val.(map[string]any)
if !ok {
return mapping.Value{}, castError(reflect.TypeOf(val).Name(), reflect.TypeOf(map[string]any{}).Name())
return mapping.Value{}, castError(reflect.TypeOf(val).Name(), reflectTypeMapString.Name())
}

var values []mapping.Value
Expand Down Expand Up @@ -502,7 +502,7 @@ func extractSlice[S any](val S) ([]any, error) {
default:
kind := reflect.TypeOf(val).Kind()
if kind != reflect.Array && kind != reflect.Slice {
return nil, castError(reflect.TypeOf(val).String(), reflect.TypeOf(s).String())
return nil, castError(reflect.TypeOf(val).String(), reflectTypeSliceAny.String())
}

// Insert the values into the slice.
Expand Down
10 changes: 5 additions & 5 deletions vector_setters.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func setBool[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
case bool:
b = v
default:
return castError(reflect.TypeOf(val).String(), reflect.TypeOf(b).String())
return castError(reflect.TypeOf(val).String(), reflectTypeBool.String())
}
setPrimitive(vec, rowIdx, b)
return nil
Expand Down Expand Up @@ -204,7 +204,7 @@ func setEnum[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
case string:
str = v
default:
return castError(reflect.TypeOf(val).String(), reflect.TypeOf(str).String())
return castError(reflect.TypeOf(val).String(), reflectTypeString.String())
}

if v, ok := vec.namesDict[str]; ok {
Expand All @@ -219,7 +219,7 @@ func setEnum[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
return setNumeric[uint32, int64](vec, rowIdx, v)
}
} else {
return castError(reflect.TypeOf(val).String(), reflect.TypeOf(str).String())
return castError(reflect.TypeOf(val).String(), reflectTypeString.String())
}
return nil
}
Expand Down Expand Up @@ -250,7 +250,7 @@ func setStruct[S any](vec *vector, rowIdx mapping.IdxT, val S) error {

// Catch mismatching types.
goType := reflect.TypeOf(val)
if reflect.TypeOf(val).Kind() != reflect.Struct {
if goType.Kind() != reflect.Struct {
return castError(goType.String(), reflect.Struct.String())
}

Expand Down Expand Up @@ -294,7 +294,7 @@ func setMap[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
case Map:
m = v
default:
return castError(reflect.TypeOf(val).String(), reflect.TypeOf(m).String())
return castError(reflect.TypeOf(val).String(), reflectTypeMap.String())
}

// Create a LIST of STRUCT values.
Expand Down