Skip to content

Commit a9f55d3

Browse files
authored
fix(query): Fix cast array to vector failed (#18375)
* fix(query): Fix cast array to vector failed * add tests
1 parent ce909b6 commit a9f55d3

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

src/query/expression/src/evaluator.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,7 @@ impl<'a> Evaluator<'a> {
941941
.set_span(span));
942942
}
943943
let mut vals = Vec::with_capacity(dimension);
944+
let col = col.remove_nullable();
944945
match col {
945946
Column::Number(num_col) => {
946947
for i in 0..dimension {

tests/sqllogictests/suites/udf_server/udf_server_test.test

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,36 @@ EvalScalar
578578
├── estimated rows: 1.00
579579
└── DummyTableScan
580580

581+
582+
statement ok
583+
CREATE OR REPLACE FUNCTION embedding_4 (VARCHAR) RETURNS ARRAY(FLOAT NULL) LANGUAGE python IMMUTABLE HANDLER = 'embedding_4' ADDRESS = 'http://0.0.0.0:8815';
584+
585+
query T
586+
SELECT embedding_4('databend.com')::vector(4);
587+
----
588+
[1.1,1.2,1.3,1.4]
589+
590+
statement ok
591+
CREATE OR REPLACE TABLE test(url STRING, length INT64);
592+
593+
statement ok
594+
INSERT INTO test (url) VALUES('databend.com'),('databend.cn');
595+
596+
query T
597+
SELECT embedding_4('databend.com')::vector(4) fro
598+
----
599+
[1.1,1.2,1.3,1.4]
600+
601+
query T
602+
SELECT embedding_4(url)::vector(4) FROM test;
603+
----
604+
[1.1,1.2,1.3,1.4]
605+
[1.1,1.2,1.3,1.4]
606+
607+
statement ok
608+
drop FUNCTION embedding_4;
609+
610+
581611
statement ok
582612
remove @udf_stage;
583613

tests/udf/udf_server.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ def ping(s: str) -> str:
414414
def check_headers() -> str:
415415
return "success"
416416

417+
@udf(
418+
input_types=["VARCHAR"],
419+
result_type="ARRAY(FLOAT32 NULL)"
420+
)
421+
def embedding_4(s: str):
422+
return [1.1, 1.2, 1.3, 1.4]
417423

418424
if __name__ == "__main__":
419425
udf_server = CheckHeadersServer(
@@ -444,7 +450,9 @@ def check_headers() -> str:
444450
udf_server.add_function(wait_concurrent)
445451
udf_server.add_function(url_len)
446452
udf_server.add_function(check_headers)
453+
udf_server.add_function(embedding_4)
447454

448455
# Built-in function
449456
udf_server.add_function(ping)
450457
udf_server.serve()
458+

0 commit comments

Comments
 (0)