diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index 25a461d949..b7cfbde026 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -3105,6 +3105,76 @@ namespace dlib
             }
         }
 
+    // ------------------------------------------------------------------------------------
+
+        void copy_tensor(
+            bool add_to,
+            tensor& dest,
+            size_t dk, size_t dnr, size_t dnc,
+            const tensor& src,
+            size_t sk, size_t snr, size_t snc,
+            size_t k, size_t nr, size_t nc
+        )
+        {
+            size_t dest_stride_sample = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
+            size_t dest_stride_k      = static_cast<size_t>(dest.nc() * dest.nr());
+            size_t dest_stride_nr     = static_cast<size_t>(dest.nc());
+
+            size_t src_stride_sample = static_cast<size_t>(src.nc() * src.nr() * src.k());
+            size_t src_stride_k      = static_cast<size_t>(src.nc() * src.nr());
+            size_t src_stride_nr     = static_cast<size_t>(src.nc());
+
+            DLIB_CASSERT(dest.num_samples() == src.num_samples(), "All sources should fit into dest tensor size");
+            DLIB_CASSERT(dest.k() - dk >= k &&
+                dest.nr() - dnr >= nr &&
+                dest.nc() - dnc >= nc, "Not enough space in dest tensor");
+            DLIB_CASSERT(src.k() - sk >= k &&
+                src.nr() - snr >= nr &&
+                src.nc() - snc >= nc, "Not enough space in src tensor");
+
+            float* dest_p = dest.host() + dk * dest_stride_k
+                                        + dnr * dest_stride_nr
+                                        + dnc;
+
+            const float* src_p = src.host() + sk * src_stride_k
+                                            + snr * src_stride_nr
+                                            + snc;
+
+            for (long i = 0; i < src.num_samples(); ++i)
+            {
+                float* dest_channel_p = dest_p;
+                const float* src_channel_p = src_p;
+
+                for (size_t j = 0; j < k; ++j)
+                {
+                    float* dest_row_p = dest_channel_p;
+                    const float* src_row_p = src_channel_p;
+
+                    for (size_t r = 0; r < nr; ++r)
+                    {
+                        if (add_to)
+                        {
+                            for (size_t c = 0; c < nc; ++c)
+                                dest_row_p[c] += src_row_p[c];
+                        }
+                        else
+                        {
+                            ::memcpy(dest_row_p, src_row_p, nc * sizeof(float));
+                        }
+
+                        dest_row_p += dest_stride_nr;
+                        src_row_p += src_stride_nr;
+                    }
+
+                    dest_channel_p += dest_stride_k;
+                    src_channel_p += src_stride_k;
+                }
+
+                dest_p += dest_stride_sample;
+                src_p += src_stride_sample;
+            }
+        }
+
     // ------------------------------------------------------------------------------------
 
         void transpose(
diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h
index f35b3c9728..ab88a4a4ee 100644
--- a/dlib/cuda/cpu_dlib.h
+++ b/dlib/cuda/cpu_dlib.h
@@ -692,6 +692,17 @@ namespace dlib
             size_t count_k
         );
 
+    // -----------------------------------------------------------------------------------
+
+        void copy_tensor(
+            bool add_to,
+            tensor& dest,
+            size_t dk, size_t dnr, size_t dnc,
+            const tensor& src,
+            size_t sk, size_t snr, size_t snc,
+            size_t k, size_t nr, size_t nc
+        );
+
     // -----------------------------------------------------------------------------------
 
         void transpose(
diff --git a/dlib/cuda/cuda_dlib.cu b/dlib/cuda/cuda_dlib.cu
index c650e89bee..5d6ec4052c 100644
--- a/dlib/cuda/cuda_dlib.cu
+++ b/dlib/cuda/cuda_dlib.cu
@@ -2623,6 +2623,77 @@ namespace dlib
             }
         }
 
+        __global__ void _cuda_copy_strided_tensor_add_to (float* dest, const float* src,
+                                                          size_t ns, size_t nk, size_t nr, size_t nc,
+                                                          size_t dk, size_t dr, size_t dc,
+                                                          size_t sk, size_t sr, size_t sc)
+        {
+            for(auto i : grid_stride_range(0, ns*nk*nr*nc))
+            {
+                size_t n,k,r,c;
+                unpack_idx(i, nk,nr,nc, n,k,r,c);
+                dest[pack_idx(dk,dr,dc, n,k,r,c)] += src[pack_idx(sk,sr,sc, n,k,r,c)];
+            }
+        }
+
+        __global__ void _cuda_copy_strided_tensor (float* dest, const float* src,
+                                                   size_t ns, size_t nk, size_t nr, size_t nc,
+                                                   size_t dk, size_t dr, size_t dc,
+                                                   size_t sk, size_t sr, size_t sc)
+        {
+            for(auto i : grid_stride_range(0, ns*nk*nr*nc))
+            {
+                size_t n,k,r,c;
+                unpack_idx(i, nk,nr,nc, n,k,r,c);
+                dest[pack_idx(dk,dr,dc, n,k,r,c)] = src[pack_idx(sk,sr,sc, n,k,r,c)];
+            }
+        }
+
+        void copy_tensor(
+            bool add_to,
+            tensor& dest,
+            size_t dk, size_t dnr, size_t dnc,
+            const tensor& src,
+            size_t sk, size_t snr, size_t snc,
+            size_t k, size_t nr, size_t nc
+        )
+        {
+
+            DLIB_CASSERT(dest.num_samples() == src.num_samples(), "All sources should fit into dest tensor size");
+            DLIB_CASSERT(dest.k() - dk >= k &&
+                dest.nr() - dnr >= nr &&
+                dest.nc() - dnc >= nc, "Not enough space in dest tensor");
+            DLIB_CASSERT(src.k() - sk >= k &&
+                src.nr() - snr >= nr &&
+                src.nc() - snc >= nc, "Not enough space in src tensor");
+
+            float* dest_p = dest.device() + dk * static_cast<size_t>(dest.nc() * dest.nr())
+                                          + dnr * static_cast<size_t>(dest.nc())
+                                          + dnc;
+
+            const float* src_p = src.device() + sk * static_cast<size_t>(src.nc() * src.nr())
+                                              + snr * static_cast<size_t>(src.nc())
+                                              + snc;
+
+            if (add_to)
+            {
+                launch_kernel(_cuda_copy_strided_tensor_add_to, max_jobs(dest.size()),
+                              dest_p, src_p, dest.num_samples(),
+                              k, nr, nc,
+                              dest.k(), dest.nr(), dest.nc(),
+                              src.k(), src.nr(), src.nc());
+            }
+            else
+            {
+                launch_kernel(_cuda_copy_strided_tensor, max_jobs(dest.size()),
+                              dest_p, src_p, dest.num_samples(),
+                              k, nr, nc,
+                              dest.k(), dest.nr(), dest.nc(),
+                              src.k(), src.nr(), src.nc());
+            }
+        }
+
     // ----------------------------------------------------------------------------------------
 
         __global__ void _cuda_transpose(size_t dsize, size_t dk, size_t dnr, size_t dnc, float* d,
diff --git a/dlib/cuda/cuda_dlib.h b/dlib/cuda/cuda_dlib.h
index dab3627b1b..2f22b7e23e 100644
--- a/dlib/cuda/cuda_dlib.h
+++ b/dlib/cuda/cuda_dlib.h
@@ -589,6 +589,17 @@ namespace dlib
             size_t count_k
         );
 
+    // ----------------------------------------------------------------------------------------
+
+        void copy_tensor(
+            bool add_to,
+            tensor& dest,
+            size_t dk, size_t dnr, size_t dnc,
+            const tensor& src,
+            size_t sk, size_t snr, size_t snc,
+            size_t k, size_t nr, size_t nc
+        );
+
     // ----------------------------------------------------------------------------------------
 
         void transpose(
diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp
index 069b4d4659..8dece8369f 100644
--- a/dlib/cuda/tensor_tools.cpp
+++ b/dlib/cuda/tensor_tools.cpp
@@ -1333,6 +1333,24 @@ namespace dlib { namespace tt
 #endif
     }
 
+// ----------------------------------------------------------------------------------------
+
+    void copy_tensor(
+        bool add_to,
+        tensor& dest,
+        size_t dk, size_t dnr, size_t dnc,
+        const tensor& src,
+        size_t sk, size_t snr, size_t snc,
+        size_t k, size_t nr, size_t nc
+    )
+    {
+#ifdef DLIB_USE_CUDA
+        cuda::copy_tensor(add_to, dest, dk, dnr, dnc, src, sk, snr, snc, k, nr, nc);
+#else
+        cpu::copy_tensor(add_to, dest, dk, dnr, dnc, src, sk, snr, snc, k, nr, nc);
+#endif
+    }
+
 // ----------------------------------------------------------------------------------------
 
     void inv::
diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h
index 17649603d9..18a5564f98 100644
--- a/dlib/cuda/tensor_tools.h
+++ b/dlib/cuda/tensor_tools.h
@@ -2334,6 +2334,38 @@ namespace dlib { namespace tt
             i.e., copies content of each sample from src in to corresponding place of sample at dest.
     !*/
 
+// ----------------------------------------------------------------------------------------
+
+    void copy_tensor(
+        bool add_to,
+        tensor& dest,
+        size_t dk, size_t dnr, size_t dnc,
+        const tensor& src,
+        size_t sk, size_t snr, size_t snc,
+        size_t k, size_t nr, size_t nc
+    );
+    /*!
+        requires
+            - dest.num_samples() == src.num_samples()
+            - dest.k() - dk >= k
+            - dest.nr() - dnr >= nr
+            - dest.nc() - dnc >= nc
+            - src.k() - sk >= k
+            - src.nr() - snr >= nr
+            - src.nc() - snc >= nc
+            - is_same_object(dest,src) == false
+            - The memory areas of src and dest do not overlap.
+        ensures
+            - if (add_to) then
+                - performs: dest[i, j + dk, r + dnr, c + dnc] += src[i, j + sk, r + snr, c + snc],
+                  where j is in [0..k), r is in [0..nr), and c is in [0..nc)
+                  i.e., adds the selected sub-region of each sample of src into the corresponding place of the sample in dest.
+            - else
+                - performs: dest[i, j + dk, r + dnr, c + dnc] = src[i, j + sk, r + snr, c + snc],
+                  where j is in [0..k), r is in [0..nr), and c is in [0..nc)
+                  i.e., copies the selected sub-region of each sample of src into the corresponding place of the sample in dest.
+    !*/
+
 // ----------------------------------------------------------------------------------------
 
     void transpose(
diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index 023ccbf810..0a0c547f33 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -4631,6 +4631,131 @@ namespace dlib
         >
     using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    template <
+        long _offset_k,
+        long _offset_nr,
+        long _offset_nc,
+        long _k,
+        long _nr,
+        long _nc
+        >
+    class slice_
+    {
+        static_assert(_offset_k >= 0, "The channel offset must be >= 0.");
+        static_assert(_offset_nr >= 0, "The row offset must be >= 0.");
+        static_assert(_offset_nc >= 0, "The column offset must be >= 0.");
+        static_assert(_k > 0, "The number of channels must be > 0.");
+        static_assert(_nr > 0, "The number of rows must be > 0.");
+        static_assert(_nc > 0, "The number of columns must be > 0.");
+    public:
+        slice_(
+        )
+        {
+        }
+
+        template <typename SUBNET>
+        void setup (const SUBNET& sub)
+        {
+            DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset_k+_offset_nr+_offset_nc+_k*_nr*_nc),
+                "The tensor we are trying to slice from the input tensor is too big to fit into the input tensor.");
+        }
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output)
+        {
+            output.set_size(sub.get_output().num_samples(), _k, _nr, _nc);
+            tt::copy_tensor(false, output, 0, 0, 0, sub.get_output(), _offset_k, _offset_nr, _offset_nc, _k, _nr, _nc);
+        }
+
+        template <typename SUBNET>
+        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+        {
+            tt::copy_tensor(true, sub.get_gradient_input(), _offset_k, _offset_nr, _offset_nc, gradient_input, 0, 0, 0, _k, _nr, _nc);
+        }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        friend void serialize(const slice_& /*item*/, std::ostream& out)
+        {
+            serialize("slice_", out);
+            serialize(_offset_k, out);
+            serialize(_offset_nr, out);
+            serialize(_offset_nc, out);
+            serialize(_k, out);
+            serialize(_nr, out);
+            serialize(_nc, out);
+        }
+
+        friend void deserialize(slice_& /*item*/, std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "slice_")
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::slice_.");
+
+            long offset_k;
+            long offset_nr;
+            long offset_nc;
+            long k;
+            long nr;
+            long nc;
+            deserialize(offset_k, in);
+            deserialize(offset_nr, in);
+            deserialize(offset_nc, in);
+            deserialize(k, in);
+            deserialize(nr, in);
+            deserialize(nc, in);
+
+            if (offset_k != _offset_k) throw serialization_error("Wrong offset_k found while deserializing dlib::slice_");
+            if (offset_nr != _offset_nr) throw serialization_error("Wrong offset_nr found while deserializing dlib::slice_");
serialization_error("Wrong offset_nr found while deserializing dlib::slice_"); + if (offset_nc != _offset_nc) throw serialization_error("Wrong offset_nc found while deserializing dlib::slice_"); + if (k != _k) throw serialization_error("Wrong k found while deserializing dlib::slice_"); + if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::slice_"); + if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::slice_"); + } + + friend std::ostream& operator<<(std::ostream& out, const slice_& /*item*/) + { + out << "slice\t (" + << "offset_k="<<_offset_k + << "offset_nr="<<_offset_nr + << "offset_nc="<<_offset_nc + << ", k="<<_k + << ", nr="<<_nr + << ", nc="<<_nc + << ")"; + return out; + } + + friend void to_xml(const slice_& /*item*/, std::ostream& out) + { + out << "\n"; + } + private: + resizable_tensor params; // unused + }; + + template < + long offset_k, + long offset_nr, + long offset_nc, + long k, + long nr, + long nc, + typename SUBNET + > + using slice = add_layer, SUBNET>; + // ---------------------------------------------------------------------------------------- template diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h index 99fe91401c..ef2de8e6fe 100644 --- a/dlib/dnn/layers_abstract.h +++ b/dlib/dnn/layers_abstract.h @@ -3699,6 +3699,85 @@ namespace dlib > using extract = add_layer, SUBNET>; +// ---------------------------------------------------------------------------------------- + + template < + long _offset_k, + long _offset_nr, + long _offset_nc, + long _k, + long _nr, + long _nc + > + class slice_ + { + /*! + REQUIREMENTS ON TEMPLATE ARGUMENTS + - 0 <= _offset_k + - 0 <= _offset_nr + - 0 <= _offset_nc + - 0 < _k + - 0 < _nr + - 0 < _nc + + WHAT THIS OBJECT REPRESENTS + This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface + defined above. In particular, the output of this layer is simply a copy of + the input tensor. It is similar to extract in that you can configure the + slice layer to output only some subset of the input tensor, but slice allows + copies of non-contiguous regions of the input which enables three dimensional + cropping of a tensor. The dimensions of the tensor output by this layer + are as follows (letting IN be the input tensor and OUT the output tensor): + - OUT.num_samples() == IN.num_samples() + - OUT.k() == _k + - OUT.nr() == _nr + - OUT.nc() == _nc + + So the output will always have the same number of samples as the input, but + within each sample (the k,nr,nc part) we will copy only a subset of the + values. Moreover, the _offset_k, _offset_nr, and _offset_nc parameters + control which channels, rows, and columns of each sample we take. + To be very precise, we will have: + - let IN_SIZE = IN.k()*IN.nr()*IN.nc() + - let OUT_SIZE = _k*_nr*_nc + - for i in range[0,IN.num_samples()) and j in range[0,OUT_SIZE): + - let k = (j / (OUT.nr()*OUT.nc())) % OUT.k() + - let r = (j / OUT.nc()) % IN.nr() + - let c = j % OUT.nc() + - OUT.host()[i*OUT_SIZE+j] == IN.host()[i*IN_SIZE+ + k_stride*(_offset_k+k)+ + row_stride*(_offset_nr+r)+ + col_stride*(_offset_nc+c)] + + + Finally, all this means that the input tensor to this layer must have a big + enough size to accommodate taking a _k*_nr*_nc slice from each of its + samples. 
+        !*/
+
+    public:
+
+        template <typename SUBNET> void setup (const SUBNET& sub);
+        template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+        template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+        const tensor& get_layer_params() const;
+        tensor& get_layer_params();
+        /*!
+            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+        !*/
+    };
+
+    template <
+        long offset_k,
+        long offset_nr,
+        long offset_nc,
+        long k,
+        long nr,
+        long nc,
+        typename SUBNET
+        >
+    using slice = add_layer<slice_<offset_k,offset_nr,offset_nc,k,nr,nc>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     template <
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index a99123e7f8..9316a0edc6 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -2209,6 +2209,24 @@ void test_embeddings()
             auto res = test_layer(l);
             DLIB_TEST_MSG(res, res);
         }
+        {
+            print_spinner();
+            slice_<0,0,0,2,2,2> l;
+            auto res = test_layer(l);
+            DLIB_TEST_MSG(res, res);
+        }
+        {
+            print_spinner();
+            slice_<1,1,1,1,1,1> l;
+            auto res = test_layer(l);
+            DLIB_TEST_MSG(res, res);
+        }
+        {
+            print_spinner();
+            slice_<0,0,0,1,1,1> l;
+            auto res = test_layer(l);
+            DLIB_TEST_MSG(res, res);
+        }
         {
             print_spinner();
             upsample_<1,1> l;
@@ -2751,6 +2769,140 @@ void test_embeddings()
                 }
             }
         }
    }
+
+    void test_copy_tensor_slice_cpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 6, 12);
+        resizable_tensor src3(10, 9, 7, 15);
+        tt::tensor_rand rnd;
+        rnd.fill_gaussian(dest);
+        rnd.fill_gaussian(src1);
+        rnd.fill_gaussian(src2);
+        rnd.fill_gaussian(src3);
+
+        const resizable_tensor old_dest = dest;
+
+        cpu::copy_tensor(false, dest, 0, 0, 0, src1, 0, 0, 0, src1.k(), src1.nr(), src1.nc()); // full copy src1 -> dest
+        cpu::copy_tensor(false, dest, src1.k(), 0, 0, src2, 0, 0, 0, src2.k(), src2.nr(), src2.nc()); // full copy src2 -> dest at channel offset src1.k()
+        cpu::copy_tensor(false, dest, src1.k() + src2.k(), 1, 1, src3, 3, 1, 1, 3, src3.nr()-2, src3.nc()-2); // partial copy of src3 into the remaining part of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+                        float dest_value = tensor_read_cpu(dest, i, k, r, c);
+                        // first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_cpu(src1, i, k, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // second part is from src2
+                        else if (k < src1.k() + src2.k())
+                        {
+                            if (r < src2.nr() && c < src2.nc())
+                            {
+                                float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c);
+                                DLIB_TEST(src_value == dest_value);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                        // third part is from src3
+                        else
+                        {
+                            if (r > 0 && c > 0 && r + 1 < src3.nr() && c + 1 < src3.nc())
+                            {
+                                float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c);
+                                DLIB_TEST(src_value == dest_value);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    void test_copy_tensor_slice_add_to_cpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 6, 12);
+        resizable_tensor src3(10, 9, 7, 15);
+        tt::tensor_rand rnd;
+        rnd.fill_gaussian(dest);
+        rnd.fill_gaussian(src1);
+        rnd.fill_gaussian(src2);
+        rnd.fill_gaussian(src3);
+
+        const resizable_tensor old_dest = dest;
+
+        cpu::copy_tensor(true, dest, 0, 0, 0, src1, 0, 0, 0, src1.k(), src1.nr(), src1.nc()); // full add src1 -> dest
+        cpu::copy_tensor(true, dest, src1.k(), 0, 0, src2, 0, 0, 0, src2.k(), src2.nr(), src2.nc()); // full add src2 -> dest at channel offset src1.k()
+        cpu::copy_tensor(true, dest, src1.k() + src2.k(), 1, 1, src3, 3, 1, 1, 3, src3.nr()-2, src3.nc()-2); // partial add of src3 into the remaining part of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+                        float dest_value = tensor_read_cpu(dest, i, k, r, c);
+                        // first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_cpu(src1, i, k, r, c) + old_dest_value;
+                            DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                        }
+                        // second part is from src2
+                        else if (k < src1.k() + src2.k())
+                        {
+                            if (r < src2.nr() && c < src2.nc())
+                            {
+                                float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c) + old_dest_value;
+                                DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                        // third part is from src3
+                        else
+                        {
+                            if (r > 0 && c > 0 && r + 1 < src3.nr() && c + 1 < src3.nc())
+                            {
+                                float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c) + old_dest_value;
+                                DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
 #ifdef DLIB_USE_CUDA
     void test_copy_tensor_gpu()
     {
@@ -2856,6 +3008,140 @@ void test_embeddings()
             }
         }
     }
+
+    void test_copy_tensor_slice_gpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 6, 12);
+        resizable_tensor src3(10, 9, 7, 15);
+        tt::tensor_rand rnd;
+        rnd.fill_gaussian(dest);
+        rnd.fill_gaussian(src1);
+        rnd.fill_gaussian(src2);
+        rnd.fill_gaussian(src3);
+
+        const resizable_tensor old_dest = dest;
+
+        cuda::copy_tensor(false, dest, 0, 0, 0, src1, 0, 0, 0, src1.k(), src1.nr(), src1.nc()); // full copy src1 -> dest
+        cuda::copy_tensor(false, dest, src1.k(), 0, 0, src2, 0, 0, 0, src2.k(), src2.nr(), src2.nc()); // full copy src2 -> dest at channel offset src1.k()
+        cuda::copy_tensor(false, dest, src1.k() + src2.k(), 1, 1, src3, 3, 1, 1, 3, src3.nr()-2, src3.nc()-2); // partial copy of src3 into the remaining part of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+                        float dest_value = tensor_read_cpu(dest, i, k, r, c);
+                        // first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_cpu(src1, i, k, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // second part is from src2
+                        else if (k < src1.k() + src2.k())
+                        {
+                            if (r < src2.nr() && c < src2.nc())
+                            {
+                                float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c);
+                                DLIB_TEST(src_value == dest_value);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                        // third part is from src3
+                        else
+                        {
+                            if (r > 0 && c > 0 && r + 1 < src3.nr() && c + 1 < src3.nc())
+                            {
+                                float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c);
+                                DLIB_TEST(src_value == dest_value);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    void test_copy_tensor_slice_add_to_gpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 6, 12);
+        resizable_tensor src3(10, 9, 7, 15);
+        tt::tensor_rand rnd;
+        rnd.fill_gaussian(dest);
+        rnd.fill_gaussian(src1);
+        rnd.fill_gaussian(src2);
+        rnd.fill_gaussian(src3);
+
+        const resizable_tensor old_dest = dest;
+
+        cuda::copy_tensor(true, dest, 0, 0, 0, src1, 0, 0, 0, src1.k(), src1.nr(), src1.nc()); // full add src1 -> dest
+        cuda::copy_tensor(true, dest, src1.k(), 0, 0, src2, 0, 0, 0, src2.k(), src2.nr(), src2.nc()); // full add src2 -> dest at channel offset src1.k()
+        cuda::copy_tensor(true, dest, src1.k() + src2.k(), 1, 1, src3, 3, 1, 1, 3, src3.nr()-2, src3.nc()-2); // partial add of src3 into the remaining part of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float old_dest_value = tensor_read_cpu(old_dest, i, k, r, c);
+                        float dest_value = tensor_read_cpu(dest, i, k, r, c);
+                        // first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_cpu(src1, i, k, r, c) + old_dest_value;
+                            DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                        }
+                        // second part is from src2
+                        else if (k < src1.k() + src2.k())
+                        {
+                            if (r < src2.nr() && c < src2.nc())
+                            {
+                                float src_value = tensor_read_cpu(src2, i, k - src1.k(), r, c) + old_dest_value;
+                                DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                        // third part is from src3
+                        else
+                        {
+                            if (r > 0 && c > 0 && r + 1 < src3.nr() && c + 1 < src3.nc())
+                            {
+                                float src_value = tensor_read_cpu(src3, i, k - src1.k() - src2.k() + 3, r, c) + old_dest_value;
+                                DLIB_TEST(std::abs(src_value - dest_value) < 1e-6);
+                            }
+                            else
+                            {
+                                DLIB_TEST(old_dest_value == dest_value);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
 #endif//DLIB_USE_CUDA
 
     template <typename SUBNET> using concat_block1 = con<5,1,1,1,1,SUBNET>;
@@ -4756,6 +5042,8 @@ void test_multm_prev()
             compare_adam();
             test_copy_tensor_gpu();
             test_copy_tensor_add_to_gpu();
+            test_copy_tensor_slice_gpu();
+            test_copy_tensor_slice_add_to_gpu();
             test_scale_channels();
 #endif
             test_tensor_resize_bilinear(2, 3, 6,6, 11, 11);
@@ -4810,6 +5098,8 @@ void test_multm_prev()
             test_visit_functions();
             test_copy_tensor_cpu();
             test_copy_tensor_add_to_cpu();
+            test_copy_tensor_slice_cpu();
+            test_copy_tensor_slice_add_to_cpu();
             test_concat();
             test_multm_prev();
             test_simple_linear_regression();
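For reference, here is a small usage sketch (not part of the patch) showing how the new slice layer and the offset-taking tt::copy_tensor overload added above could be used. The network shape, tensor sizes, and the assumption of at least 16x16 input images are illustrative only; the template parameter order (offset_k, offset_nr, offset_nc, k, nr, nc) follows the declarations in this patch.

    // Minimal sketch, assuming this patch is applied to dlib.
    #include <dlib/dnn.h>

    using namespace dlib;

    // Keep 2 channels starting at channel 1, and an 8x8 window starting at
    // row 4, column 4, of the convolution output, then classify it.  With
    // 16x16 inputs the con<4,5,5,1,1> layer produces 4x12x12 outputs, which is
    // large enough for this slice.
    using net_type = loss_multiclass_log<
                         fc<10,
                         slice<1,4,4, 2,8,8,     // offset_k, offset_nr, offset_nc, k, nr, nc
                         relu<con<4,5,5,1,1,
                         input<matrix<unsigned char>>>>>>>;

    int main()
    {
        // The tensor-level routine can also be called directly, e.g. to paste a
        // 3x6x12 block of src into dest starting at channel 2, row 1, column 1.
        resizable_tensor dest(1, 9, 7, 15);
        resizable_tensor src(1, 3, 6, 12);
        dest = 0;
        src = 1;
        tt::copy_tensor(false, dest, 2, 1, 1, src, 0, 0, 0, src.k(), src.nr(), src.nc());
    }

Passing true as the first argument would accumulate the block into dest instead of overwriting it, which is how slice_::backward() routes gradients back into the right sub-region of the input gradient tensor.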