Skip to content

Commit 9b9a318

Browse files
authored
[XPU][HOST] add unstack op and kernels (#4876) (#4892)
1 parent dc3676e commit 9b9a318

File tree

12 files changed

+492
-0
lines changed

12 files changed

+492
-0
lines changed

lite/kernels/host/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
message(STATUS "compile with lite host kernels")
22

3+
# basic kernels
34
add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
45
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
56
add_kernel(range_compute_host Host basic SRCS range_compute.cc DEPS ${lite_kernel_deps})
@@ -12,6 +13,9 @@ add_kernel(expand_as_compute_host Host basic SRCS expand_as_compute.cc DEPS ${li
1213
add_kernel(fill_constant_compute_host Host basic SRCS fill_constant_compute.cc DEPS ${lite_kernel_deps})
1314
add_kernel(fill_constant_batch_size_like_compute_host Host basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps})
1415
add_kernel(stack_compute_host Host basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
16+
17+
# extra kernels
18+
add_kernel(unstack_compute_host Host extra SRCS unstack_compute.cc DEPS ${lite_kernel_deps})
1519
add_kernel(shape_compute_host Host extra SRCS shape_compute.cc DEPS ${lite_kernel_deps})
1620
add_kernel(is_empty_compute_host Host extra SRCS is_empty_compute.cc DEPS ${lite_kernel_deps})
1721
add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})

lite/kernels/host/unstack_compute.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/host/unstack_compute.h"
16+
#include <cstring>
17+
#include <vector>
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace host {
23+
24+
template <typename T, PrecisionType PType>
25+
void UnstackCompute<T, PType>::Run() {
26+
auto& param = this->template Param<operators::UnstackParam>();
27+
auto x = param.X;
28+
auto outs = param.Out;
29+
auto x_dims = x->dims();
30+
int axis = param.axis;
31+
if (axis < 0) {
32+
axis += x_dims.size();
33+
}
34+
35+
size_t stride_copy = 1;
36+
for (size_t i = axis + 1; i < x_dims.size(); i++) {
37+
stride_copy *= static_cast<size_t>(x_dims[i]);
38+
}
39+
size_t stride_move = stride_copy * static_cast<size_t>(x_dims[axis]);
40+
size_t copy_times = static_cast<size_t>(x_dims.production()) / stride_move;
41+
42+
const T* x_data = x->template data<T>();
43+
for (size_t i = 0; i < outs.size(); i++) {
44+
const T* x_ptr = x_data + i * stride_copy;
45+
T* out_ptr = outs[i]->template mutable_data<T>();
46+
for (size_t j = 0; j < copy_times; j++) {
47+
std::memcpy(out_ptr, x_ptr, sizeof(T) * stride_copy);
48+
x_ptr += stride_move;
49+
out_ptr += stride_copy;
50+
}
51+
}
52+
}
53+
54+
} // namespace host
55+
} // namespace kernels
56+
} // namespace lite
57+
} // namespace paddle
58+
59+
using unstack_float =
60+
paddle::lite::kernels::host::UnstackCompute<float, PRECISION(kFloat)>;
61+
REGISTER_LITE_KERNEL(unstack, kHost, kFloat, kAny, unstack_float, def)
62+
.BindInput("X",
63+
{LiteType::GetTensorTy(
64+
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
65+
.BindOutput("Y",
66+
{LiteType::GetTensorTy(
67+
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
68+
.Finalize();

lite/kernels/host/unstack_compute.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace host {
23+
24+
template <typename T, PrecisionType PType>
25+
class UnstackCompute
26+
: public KernelLite<TARGET(kHost), PType, DATALAYOUT(kAny)> {
27+
public:
28+
void Run() override;
29+
30+
virtual ~UnstackCompute() = default;
31+
};
32+
33+
} // namespace host
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ else()
3939
add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps})
4040
add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
4141
add_kernel(sequence_unpad_compute_xpu XPU extra SRCS sequence_unpad_compute.cc DEPS ${lite_kernel_deps})
42+
add_kernel(unstack_compute_xpu XPU extra SRCS unstack_compute.cc DEPS ${lite_kernel_deps})
4243

4344
# extra(fused kernel)
4445
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc DEPS ${lite_kernel_deps})

lite/kernels/xpu/unstack_compute.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/unstack_compute.h"
16+
#include <vector>
17+
#include "lite/backends/xpu/xpu_header_sitter.h"
18+
#include "lite/core/op_registry.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace xpu {
24+
25+
void UnstackCompute::Run() {
26+
auto& param = this->Param<param_t>();
27+
auto& ctx = this->ctx_->As<XPUContext>();
28+
auto& dout = param.Out;
29+
auto in_dim = param.X->dims();
30+
int axis = param.axis;
31+
if (axis < 0) {
32+
axis += in_dim.size();
33+
}
34+
int num = param.num;
35+
36+
int height = 1;
37+
for (int i = 0; i < axis; i++) {
38+
height = height * in_dim[i];
39+
}
40+
41+
std::vector<float*> out_ptrs;
42+
std::vector<int> width_out;
43+
44+
for (auto out : dout) {
45+
out->set_lod(param.X->lod());
46+
out_ptrs.push_back(out->mutable_data<float>(TARGET(kXPU)));
47+
width_out.push_back(out->numel() / height);
48+
}
49+
50+
int r = xdnn::concat_grad(ctx.GetRawContext(),
51+
height,
52+
width_out.data(),
53+
num,
54+
out_ptrs.data(),
55+
param.X->data<float>());
56+
CHECK_EQ(r, 0);
57+
}
58+
59+
} // namespace xpu
60+
} // namespace kernels
61+
} // namespace lite
62+
} // namespace paddle
63+
64+
// XPU registration of UnstackCompute for the `unstack` op (kFloat, kNCHW):
// input slot "X", output slot "Y", both bound to kXPU tensors.
REGISTER_LITE_KERNEL(
    unstack, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::UnstackCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
    .Finalize();

lite/kernels/xpu/unstack_compute.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace xpu {
22+
23+
// XPU kernel for the unstack op, declared for the kXPU target with kFloat
// precision.
class UnstackCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
 public:
  // Parameter struct read by Run().
  using param_t = operators::UnstackParam;

  // Marked `override` so the compiler verifies that this really overrides
  // the base-class Run().
  void Run() override;

  virtual ~UnstackCompute() = default;
};
31+
32+
} // namespace xpu
33+
} // namespace kernels
34+
} // namespace lite
35+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
123123
add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
124124
add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
125125
add_operator(scatter extra SRCS scatter_op.cc DEPS ${op_DEPS})
126+
add_operator(unstack_op extra SRCS unstack_op.cc DEPS ${op_DEPS})
126127

127128
# for OCR specific
128129
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})

lite/operators/op_params.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,15 @@ struct StackParam : ParamBase {
222222
int axis{0};
223223
};
224224

225+
// For Unstack Op
226+
struct UnstackParam : ParamBase {
227+
const lite::Tensor* X{nullptr};
228+
std::vector<lite::Tensor*> Out{};
229+
230+
int axis{0};
231+
int num{1};
232+
};
233+
225234
// For Power Op
226235
struct PowerParam : ParamBase {
227236
const lite::Tensor* X{};

lite/operators/unstack_op.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/unstack_op.h"
16+
#include "lite/core/op_lite.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace operators {
22+
23+
// Verifies that the input tensor and every output tensor have been attached
// (are non-null). Always returns true; a failed check is reported by CHECK.
bool UnstackOp::CheckShape() const {
  CHECK(param_.X);
  const size_t out_count = param_.Out.size();
  for (size_t idx = 0; idx < out_count; ++idx) {
    auto* out = param_.Out[idx];
    CHECK(out);
  }
  return true;
}
30+
31+
bool UnstackOp::InferShapeImpl() const {
32+
auto x = param_.X;
33+
auto outs = param_.Out;
34+
int axis = param_.axis;
35+
if (axis < 0) {
36+
axis += x->dims().size();
37+
}
38+
int num = param_.num;
39+
auto x_shape = x->dims().Vectorize();
40+
CHECK_EQ(x_shape[axis], static_cast<int64_t>(num))
41+
<< "num(attr) should be equal to x_dims[axis]. But received x_dims: "
42+
<< x->dims() << ", axis: " << param_.axis << ", num: " << num;
43+
44+
auto out_shape = x_shape;
45+
out_shape.erase(out_shape.begin() + axis);
46+
for (auto out : outs) {
47+
out->Resize(out_shape);
48+
}
49+
return true;
50+
}
51+
52+
// Binds the op description to param_.
//
// Reads the first variable of input slot "X" and every variable of output
// slot "Y" from `scope`, then copies the integer attributes "axis" and "num".
// Always returns true.
bool UnstackOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
  param_.X = scope->FindTensor(op_desc.Input("X").front());

  // Start from an empty list so that attaching the same op instance again
  // does not append the outputs a second time.
  param_.Out.clear();
  const auto &out_names = op_desc.Output("Y");
  param_.Out.reserve(out_names.size());
  for (const auto &out_name : out_names) {
    param_.Out.emplace_back(scope->FindMutableTensor(out_name));
  }

  param_.axis = op_desc.GetAttr<int>("axis");
  param_.num = op_desc.GetAttr<int>("num");
  return true;
}
63+
64+
} // namespace operators
65+
} // namespace lite
66+
} // namespace paddle
67+
68+
// Register UnstackOp with the op registry as `unstack`.
REGISTER_LITE_OP(unstack, paddle::lite::operators::UnstackOp);

lite/operators/unstack_op.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include "lite/core/op_lite.h"
18+
#include "lite/core/scope.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace operators {
23+
24+
class UnstackOp : public OpLite {
25+
public:
26+
UnstackOp() {}
27+
28+
explicit UnstackOp(const std::string &op_type) : OpLite(op_type) {}
29+
30+
bool CheckShape() const override;
31+
32+
bool InferShapeImpl() const override;
33+
34+
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
35+
36+
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
37+
std::string DebugString() const override { return "unstack"; }
38+
39+
private:
40+
mutable UnstackParam param_;
41+
};
42+
43+
} // namespace operators
44+
} // namespace lite
45+
} // namespace paddle

lite/tests/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ if(LITE_BUILD_EXTRA)
6666
lite_cc_test(test_kernel_clip_compute SRCS clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${rknpu_kernels} ${apu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
6767
lite_cc_test(test_kernel_pixel_shuffle_compute SRCS pixel_shuffle_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${rknpu_kernels} ${apu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
6868
lite_cc_test(test_kernel_scatter_compute SRCS scatter_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${rknpu_kernels} ${apu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
69+
lite_cc_test(test_kernel_unstack_compute SRCS unstack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${rknpu_kernels} ${apu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
6970
lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${rknpu_kernels} ${apu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
7071

7172
# for training kernel

0 commit comments

Comments
 (0)