Skip to content

Commit 37fbdbc

Browse files
hiroyuki-satokou
andauthored
GH-35419: [GLib] Add GArrowFixedShapeTensorDataType (#46305)
### Rationale for this change The C++ API implemented FixedShapeTensor Extension DataType. GLib was not yet supported. ### What changes are included in this PR? Implement GArrowFixedShapeTensorDataType. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #35419 Lead-authored-by: Hiroyuki Sato <hiroysato@gmail.com> Co-authored-by: Sutou Kouhei <kou@cozmixng.org> Signed-off-by: Sutou Kouhei <kou@clear-code.com>
1 parent dc09722 commit 37fbdbc

File tree

3 files changed

+197
-0
lines changed

3 files changed

+197
-0
lines changed

c_glib/arrow-glib/basic-data-type.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <arrow-glib/type.hpp>
2727

2828
#include <arrow/c/bridge.h>
29+
#include <arrow/extension/fixed_shape_tensor.h>
2930

3031
G_BEGIN_DECLS
3132

@@ -131,6 +132,8 @@ G_BEGIN_DECLS
131132
* #GArrowStringViewDataType is a class for the string view data type.
132133
*
133134
* #GArrowBinaryViewDataType is a class for the binary view data type.
135+
*
136+
* #GArrowFixedShapeTensorDataType is a class for the fixed shape tensor data type.
134137
*/
135138

136139
struct GArrowDataTypePrivate
@@ -2267,6 +2270,83 @@ garrow_string_view_data_type_new(void)
22672270
return data_type;
22682271
}
22692272

2273+
G_DEFINE_TYPE(GArrowFixedShapeTensorDataType,
2274+
garrow_fixed_shape_tensor_data_type,
2275+
GARROW_TYPE_EXTENSION_DATA_TYPE)
2276+
2277+
static void
2278+
garrow_fixed_shape_tensor_data_type_init(GArrowFixedShapeTensorDataType *object)
2279+
{
2280+
}
2281+
2282+
static void
2283+
garrow_fixed_shape_tensor_data_type_class_init(GArrowFixedShapeTensorDataTypeClass *klass)
2284+
{
2285+
}
2286+
2287+
/**
2288+
* garrow_fixed_shape_tensor_data_type_new:
2289+
* @value_type: A #GArrowDataType of individual tensor elements.
2290+
* @shape: (array length=shape_length): A physical shape of the contained tensors as an
2291+
* array.
2292+
* @shape_length: The length of `shape`.
2293+
* @permutation: (array length=permutation_length) (nullable): An indices of the desired
2294+
* ordering of the original dimensions, defined as an array. This must be `NULL` or
2295+
* the same length array of `shape`.
2296+
* @permutation_length: The length of `permutation`.
2297+
* @dim_names: (array length=n_dim_names) (nullable): Explicit names to tensor dimensions
2298+
* as an array. This must be `NULL` or the same length array of `shape`.
2299+
* @n_dim_names. The length of `dim_names`.
2300+
* @error: (nullable): Return location for a #GError or %NULL.
2301+
*
2302+
* Returns: The newly created fixed shape tensor data type.
2303+
*/
2304+
GArrowFixedShapeTensorDataType *
2305+
garrow_fixed_shape_tensor_data_type_new(GArrowDataType *value_type,
2306+
const gint64 *shape,
2307+
gsize shape_length,
2308+
const gint64 *permutation,
2309+
gsize permutation_length,
2310+
const gchar **dim_names,
2311+
gsize n_dim_names,
2312+
GError **error)
2313+
{
2314+
std::vector<int64_t> arrow_shape;
2315+
std::vector<int64_t> arrow_permutation;
2316+
std::vector<std::string> arrow_dim_names;
2317+
2318+
auto arrow_value_type = garrow_data_type_get_raw(value_type);
2319+
2320+
for (int i = 0; i < shape_length; i++) {
2321+
arrow_shape.push_back(shape[i]);
2322+
}
2323+
2324+
for (int i = 0; i < permutation_length; i++) {
2325+
arrow_permutation.push_back(permutation[i]);
2326+
}
2327+
2328+
for (int i = 0; i < n_dim_names; i++) {
2329+
arrow_dim_names.push_back(dim_names[i]);
2330+
}
2331+
2332+
auto arrow_data_type_result =
2333+
arrow::extension::FixedShapeTensorType::Make(arrow_value_type,
2334+
arrow_shape,
2335+
arrow_permutation,
2336+
arrow_dim_names);
2337+
if (!garrow::check(error, arrow_data_type_result, "[fixed-shape-tensor][new]")) {
2338+
return NULL;
2339+
}
2340+
2341+
auto arrow_data_type = *arrow_data_type_result;
2342+
auto data_type = GARROW_FIXED_SHAPE_TENSOR_DATA_TYPE(
2343+
g_object_new(GARROW_TYPE_FIXED_SHAPE_TENSOR_DATA_TYPE,
2344+
"data-type",
2345+
&arrow_data_type,
2346+
NULL));
2347+
return data_type;
2348+
}
2349+
22702350
G_END_DECLS
22712351

22722352
GArrowDataType *

c_glib/arrow-glib/basic-data-type.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,4 +802,28 @@ GARROW_AVAILABLE_IN_19_0
802802
GArrowStringViewDataType *
803803
garrow_string_view_data_type_new(void);
804804

805+
#define GARROW_TYPE_FIXED_SHAPE_TENSOR_DATA_TYPE \
806+
(garrow_fixed_shape_tensor_data_type_get_type())
807+
GARROW_AVAILABLE_IN_21_0
808+
G_DECLARE_DERIVABLE_TYPE(GArrowFixedShapeTensorDataType,
809+
garrow_fixed_shape_tensor_data_type,
810+
GARROW,
811+
FIXED_SHAPE_TENSOR_DATA_TYPE,
812+
GArrowExtensionDataType)
813+
struct _GArrowFixedShapeTensorDataTypeClass
814+
{
815+
GArrowExtensionDataTypeClass parent_class;
816+
};
817+
818+
GARROW_AVAILABLE_IN_21_0
819+
GArrowFixedShapeTensorDataType *
820+
garrow_fixed_shape_tensor_data_type_new(GArrowDataType *value_type,
821+
const gint64 *shape,
822+
gsize shape_length,
823+
const gint64 *permutation,
824+
gsize permutation_length,
825+
const gchar **dim_names,
826+
gsize n_dim_names,
827+
GError **error);
828+
805829
G_END_DECLS
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
class TestFixedShapeTensorDataType < Test::Unit::TestCase
19+
def test_id
20+
data_type = Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
21+
[3, 4],
22+
[1, 0],
23+
["x", "y"])
24+
assert_equal(Arrow::Type::EXTENSION, data_type.id)
25+
end
26+
27+
def test_name
28+
data_type = Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
29+
[3, 4],
30+
[1, 0],
31+
["x", "y"])
32+
assert_equal(["extension", "arrow.fixed_shape_tensor"],
33+
[data_type.name, data_type.extension_name])
34+
end
35+
36+
def test_to_s
37+
data_type = Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
38+
[3, 4],
39+
[1, 0],
40+
["x", "y"])
41+
assert do
42+
data_type.to_s.start_with?("extension<arrow.fixed_shape_tensor")
43+
end
44+
end
45+
46+
def test_nil_permutation
47+
data_type = Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
48+
[3, 4],
49+
nil,
50+
["x", "y"])
51+
# TODO: Use Arrow::FixedShapeTensorDataType#permutation
52+
assert_equal(Arrow::Type::EXTENSION, data_type.id)
53+
end
54+
55+
def test_nil_dim_names
56+
data_type = Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
57+
[3, 4],
58+
[0, 1],
59+
nil)
60+
# TODO: Use Arrow::FixedShapeTensorDataType#dim_names
61+
assert_equal(Arrow::Type::EXTENSION, data_type.id)
62+
end
63+
64+
def test_mismatch_permutation_size
65+
message =
66+
"[fixed-shape-tensor][new]: Invalid: " +
67+
"permutation size must match shape size. " +
68+
"Expected: 2 Got: 1"
69+
error = assert_raise(Arrow::Error::Invalid) do
70+
Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
71+
[3, 4],
72+
[1],
73+
["x", "y"])
74+
end
75+
assert_equal(message,
76+
error.message.lines.first.chomp)
77+
end
78+
79+
def test_mismatch_dim_names_size
80+
message =
81+
"[fixed-shape-tensor][new]: Invalid: " +
82+
"dim_names size must match shape size. " +
83+
"Expected: 2 Got: 1"
84+
error = assert_raise(Arrow::Error::Invalid) do
85+
Arrow::FixedShapeTensorDataType.new(Arrow::UInt64DataType.new,
86+
[3, 4],
87+
[],
88+
["x"])
89+
end
90+
assert_equal(message,
91+
error.message.lines.first.chomp)
92+
end
93+
end

0 commit comments

Comments
 (0)