diff --git a/dlib/cuda/cpu_dlib.cpp b/dlib/cuda/cpu_dlib.cpp
index 0e5ca5cee6..145a1344fd 100644
--- a/dlib/cuda/cpu_dlib.cpp
+++ b/dlib/cuda/cpu_dlib.cpp
@@ -1620,122 +1620,176 @@ namespace dlib
 namespace ttimpl
 {
-    void softmax (
-        const long num_locations,
-        const long num_channels,
-        tensor& dest,
-        const tensor& src
-    )
-    {
-        DLIB_ASSERT(num_channels*num_locations == src.nr()*src.nc()*src.k());
-        DLIB_CASSERT(have_same_dimensions(dest,src));
-        const auto d = dest.host();
-        const auto s = src.host();
+    void softmax(
+        const long num_locations,
+        const long num_channels,
+        tensor& dest,
+        const tensor& src,
+        size_t mode = 0
+    )
+    {
+        DLIB_ASSERT(num_channels * num_locations == src.nr() * src.nc() * src.k());
+        DLIB_CASSERT(have_same_dimensions(dest, src));
+        const auto d = dest.host();
+        const auto s = src.host();

-        // Note that we subtract out the max values in each channel before applying
-        // exp() to avoid numeric overflow in the subsequent computations.  Doing this
-        // doesn't change the resulting output, it just makes it more numerically
-        // stable.
-        for (long n = 0; n < src.num_samples(); ++n)
-        {
-            auto ss = s + num_locations*num_channels*n;
-            auto dd = d + num_locations*num_channels*n;
-            for (long i = 0; i < num_locations; ++i)
+        for (long n = 0; n < src.num_samples(); ++n)
         {
-                float max_val = -std::numeric_limits<float>::infinity();
-                for (long k = 0; k < num_channels; ++k)
-                    max_val = std::max(max_val, ss[k*num_locations]);
+            auto ss = s + num_locations * num_channels * n;
+            auto dd = d + num_locations * num_channels * n;

-                for (long k = 0; k < num_channels; ++k)
-                    dd[k*num_locations] = std::exp(ss[k*num_locations]-max_val);
+            if (mode == 0) // softmax_mode::CHANNEL_WISE
+            {
+                for (long i = 0; i < num_locations; ++i)
+                {
+                    float max_val = -std::numeric_limits<float>::infinity();
+                    for (long k = 0; k < num_channels; ++k)
+                        max_val = std::max(max_val, ss[k * num_locations]);

-                ++ss;
-                ++dd;
-            }
-        }
+                    float sum = 0.0f;
+                    for (long k = 0; k < num_channels; ++k)
+                    {
+                        dd[k * num_locations] = std::exp(ss[k * num_locations] - max_val);
+                        sum += dd[k * num_locations];
+                    }
+                    for (long k = 0; k < num_channels; ++k)
+                        dd[k * num_locations] /= sum;

-        // Now normalize each channel so they sum to 1.
-        for (long n = 0; n < src.num_samples(); ++n)
-        {
-            const auto dd = d + num_locations*num_channels*n;
-            for (long i = 0; i < num_locations; ++i)
-            {
-                const auto ddd = dd+i;
+                    ++ss;
+                    ++dd;
+                }
+            }
+            else if (mode == 1) // softmax_mode::PLANE_WISE
+            {
+                for (long k = 0; k < num_channels; ++k)
+                {
+                    auto s_channel = ss + k * num_locations;
+                    auto d_channel = dd + k * num_locations;
+                    for (long r = 0; r < src.nr(); ++r)
+                    {
+                        float max_val = -std::numeric_limits<float>::infinity();
+                        for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx)
+                            max_val = std::max(max_val, s_channel[idx]);

-                float temp = 0;
-                for (long k = 0; k < num_channels; ++k)
-                    temp += ddd[k*num_locations];
-                for (long k = 0; k < num_channels; ++k)
-                    ddd[k*num_locations] /= temp;
+                        if (max_val == -std::numeric_limits<float>::infinity())
+                        {
+                            for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx)
+                                d_channel[idx] = 0.0f;
+                        }
+                        else
+                        {
+                            float sum = 0.0f;
+                            for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx)
+                            {
+                                d_channel[idx] = std::exp(s_channel[idx] - max_val);
+                                sum += d_channel[idx];
+                            }
+                            for (long c = 0, idx = r * src.nc(); c < src.nc(); ++c, ++idx)
+                                d_channel[idx] /= sum;
+                        }
+                    }
+                }
+            }
         }
     }
-    }

-    void softmax_gradient (
-        const long num_locations,
-        const long num_channels,
-        tensor& grad,
-        const tensor& dest,
-        const tensor& gradient_input
-    )
-    {
-        DLIB_ASSERT(num_channels*num_locations == grad.nr()*grad.nc()*grad.k());
-        DLIB_CASSERT(have_same_dimensions(grad,dest));
-        DLIB_CASSERT(have_same_dimensions(grad,gradient_input));
-        const auto d = dest.host();
-        const auto g = grad.host();
-        const auto in = gradient_input.host();
-
-
-        for (long n = 0; n < grad.num_samples(); ++n)
+    void softmax_gradient(
+        const long num_locations,
+        const long num_channels,
+        tensor& grad,
+        const tensor& dest,
+        const tensor& gradient_input,
+        size_t mode = 0
+    )
     {
-            const auto d2 = d + num_locations*num_channels*n;
-            const auto g2 = g + num_locations*num_channels*n;
-            const auto in2 = in + num_locations*num_channels*n;
-            for (long i = 0; i < num_locations; ++i)
+        DLIB_ASSERT(num_channels * num_locations == grad.nr() * grad.nc() * grad.k());
+        DLIB_CASSERT(have_same_dimensions(grad, dest));
+        DLIB_CASSERT(have_same_dimensions(grad, gradient_input));
+
+        const auto d = dest.host();
+        const auto g = grad.host();
+        const auto in = gradient_input.host();
+        for (long n = 0; n < grad.num_samples(); ++n)
         {
-                const auto d3 = d2+i;
-                const auto g3 = g2+i;
-                const auto in3 = in2+i;
+            const auto d2 = d + num_locations * num_channels * n;
+            const auto g2 = g + num_locations * num_channels * n;
+            const auto in2 = in + num_locations * num_channels * n;

-                float temp = 0;
-                for (long k = 0; k < num_channels; ++k)
-                    temp += -d3[k*num_locations]*in3[k*num_locations];
-                if (is_same_object(gradient_input, grad))
+            if (mode == 0) // softmax_mode::CHANNEL_WISE
             {
-                    for (long k = 0; k < num_channels; ++k)
-                        g3[k*num_locations] = d3[k*num_locations]*(temp+in3[k*num_locations]);
+                for (long i = 0; i < num_locations; ++i)
+                {
+                    const auto d3 = d2 + i;
+                    const auto g3 = g2 + i;
+                    const auto in3 = in2 + i;
+                    float sum = 0.0f;
+                    for (long k = 0; k < num_channels; ++k)
+                        sum += -d3[k * num_locations] * in3[k * num_locations];
+                    if (is_same_object(gradient_input, grad))
+                    {
+                        for (long k = 0; k < num_channels; ++k)
+                            g3[k * num_locations] = d3[k * num_locations] * (sum + in3[k * num_locations]);
+                    }
+                    else
+                    {
+                        for (long k = 0; k < num_channels; ++k)
+                            g3[k * num_locations] += d3[k * num_locations] * (sum + in3[k * num_locations]);
+                    }
+                }
             }
-                else
+            else if (mode == 1) // softmax_mode::PLANE_WISE
             {
                 for (long k = 0; k < num_channels; ++k)
-                    g3[k*num_locations] += d3[k*num_locations]*(temp+in3[k*num_locations]);
+                {
+                    const auto d_channel = d2 + k * num_locations;
+                    const auto g_channel = g2 + k * num_locations;
+                    const auto in_channel = in2 + k * num_locations;
+                    for (long r = 0; r < grad.nr(); ++r)
+                    {
+                        float sum = 0.0f;
+                        for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx)
+                            sum += -d_channel[idx] * in_channel[idx];
+                        if (is_same_object(gradient_input, grad))
+                        {
+                            for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx)
+                                g_channel[idx] = d_channel[idx] * (sum + in_channel[idx]);
+                        }
+                        else
+                        {
+                            for (long c = 0, idx = r * grad.nc(); c < grad.nc(); ++c, ++idx)
+                                g_channel[idx] += d_channel[idx] * (sum + in_channel[idx]);
+                        }
+                    }
+                }
             }
         }
     }
-    }

 // ----------------------------------------------------------------------------------------

     void softmax (
         tensor& dest,
-        const tensor& src
+        const tensor& src,
+        size_t mode
     )
     {
         DLIB_CASSERT(have_same_dimensions(dest,src));
-        ttimpl::softmax(src.nr()*src.nc(), src.k(), dest, src);
+        DLIB_CASSERT(mode == 0 /*CHANNEL_WISE*/ || mode == 1 /*PLANE_WISE*/, "Invalid softmax mode");
+        ttimpl::softmax(src.nr()*src.nc(), src.k(), dest, src, mode);
     }

     void softmax_gradient (
         tensor& grad,
         const tensor& dest,
-        const tensor& gradient_input
+        const tensor& gradient_input,
+        size_t mode
    )
     {
         DLIB_CASSERT(have_same_dimensions(grad,dest));
         DLIB_CASSERT(have_same_dimensions(grad,gradient_input));
-        ttimpl::softmax_gradient(grad.nr()*grad.nc(), grad.k(), grad, dest, gradient_input);
+        DLIB_CASSERT(mode == 0 /*CHANNEL_WISE*/ || mode == 1 /*PLANE_WISE*/, "Invalid softmax mode");
+        ttimpl::softmax_gradient(grad.nr()*grad.nc(), grad.k(), grad, dest, gradient_input, mode);
     }

 // ------------------------------------------------------------------------------------
diff --git a/dlib/cuda/cpu_dlib.h b/dlib/cuda/cpu_dlib.h
index f26795445d..d876801993 100644
--- a/dlib/cuda/cpu_dlib.h
+++ b/dlib/cuda/cpu_dlib.h
@@ -293,13 +293,15 @@ namespace dlib
     void softmax (
         tensor& dest,
-        const tensor& src
+        const tensor& src,
+        size_t mode = 0
     );

     void softmax_gradient (
         tensor& grad,
         const tensor& dest,
-        const tensor& gradient_input
+        const tensor& gradient_input,
+        size_t mode = 0
     );

 // ------------------------------------------------------------------------------------
diff --git a/dlib/cuda/cudnn_dlibapi.cpp b/dlib/cuda/cudnn_dlibapi.cpp
index c09845cc03..dc4df68e88 100644
--- a/dlib/cuda/cudnn_dlibapi.cpp
+++ b/dlib/cuda/cudnn_dlibapi.cpp
@@ -1533,53 +1533,119 @@ namespace dlib
         void softmax (
             tensor& dest,
-            const tensor& src
+            const tensor& src,
+            size_t mode
         )
         {
             DLIB_CASSERT(have_same_dimensions(dest,src));
-            if (src.size() == 0)
-                return;
+            DLIB_CASSERT(mode == 0 /*CHANNEL_WISE*/ || mode == 1 /*PLANE_WISE*/, "Invalid softmax mode");
+            if (src.size() == 0) return;

             const float alpha = 1;
             const float beta = 0;

-            CHECK_CUDNN(cudnnSoftmaxForward(context(),
-                                      CUDNN_SOFTMAX_ACCURATE,
-                                      CUDNN_SOFTMAX_MODE_CHANNEL,
-                                      &alpha,
-                                      descriptor(src),
-                                      src.device(),
-                                      &beta,
-                                      descriptor(dest),
-                                      dest.device()));
-        }
+            if (mode == 0)
+            {
+                CHECK_CUDNN(cudnnSoftmaxForward(context(),
+                                                CUDNN_SOFTMAX_ACCURATE,
+                                                CUDNN_SOFTMAX_MODE_CHANNEL,
+                                                &alpha,
+                                                descriptor(src),
+                                                src.device(),
+                                                &beta,
+                                                descriptor(dest),
+                                                dest.device()));
+            }
+            else if (mode == 1)
+            {
+                const long num_samples = src.num_samples();
+                const long num_channels = src.k();
+                const size_t plane_size = src.nr() * src.nc();
+                for (long s = 0; s < num_samples; ++s)
+                {
+                    for (long k = 0; k < num_channels; ++k)
+                    {
+                        auto src_slice = src.device() + (s * num_channels + k) * plane_size;
+                        auto dest_slice = dest.device() + (s * num_channels + k) * plane_size;
+                        auto a_src_slice = alias_tensor(src.nr(), src.nc())(src, (s * num_channels + k) * plane_size);
+                        auto a_dest_slice = alias_tensor(dest.nr(), dest.nc())(dest, (s * num_channels + k) * plane_size);
+
+                        CHECK_CUDNN(cudnnSoftmaxForward(context(),
+                                                        CUDNN_SOFTMAX_ACCURATE,
+                                                        CUDNN_SOFTMAX_MODE_CHANNEL,
+                                                        &alpha,
+                                                        descriptor(a_src_slice),
+                                                        src_slice,
+                                                        &beta,
+                                                        descriptor(a_dest_slice),
+                                                        dest_slice));
+                    }
+                }
+            }
+        }

         void softmax_gradient (
             tensor& grad,
             const tensor& dest,
-            const tensor& gradient_input
+            const tensor& gradient_input,
+            size_t mode
         )
         {
             DLIB_CASSERT(
                   have_same_dimensions(dest,gradient_input) == true &&
                   have_same_dimensions(dest,grad) == true );
-            if (dest.size() == 0)
-                return;
+            DLIB_CASSERT(mode == 0 /*CHANNEL_WISE*/ || mode == 1 /*PLANE_WISE*/, "Invalid softmax mode");
+            if (dest.size() == 0) return;

             const float alpha = 1;
-            const float beta = is_same_object(grad,gradient_input) ? 0 : 1;
-            CHECK_CUDNN(cudnnSoftmaxBackward(context(),
-                                      CUDNN_SOFTMAX_ACCURATE,
-                                      CUDNN_SOFTMAX_MODE_CHANNEL,
-                                      &alpha,
-                                      descriptor(dest),
-                                      dest.device(),
-                                      descriptor(gradient_input),
-                                      gradient_input.device(),
-                                      &beta,
-                                      descriptor(grad),
-                                      grad.device()));
+            const float beta = is_same_object(grad, gradient_input) ? 0 : 1;
+
+            if (mode == 0)
+            {
+                CHECK_CUDNN(cudnnSoftmaxBackward(context(),
+                                                 CUDNN_SOFTMAX_ACCURATE,
+                                                 CUDNN_SOFTMAX_MODE_CHANNEL,
+                                                 &alpha,
+                                                 descriptor(dest),
+                                                 dest.device(),
+                                                 descriptor(gradient_input),
+                                                 gradient_input.device(),
+                                                 &beta,
+                                                 descriptor(grad),
+                                                 grad.device()));
+            }
+            else if (mode == 1)
+            {
+                const long num_samples = dest.num_samples();
+                const long num_channels = dest.k();
+                const size_t plane_size = dest.nr() * dest.nc();
+
+                for (long s = 0; s < num_samples; ++s)
+                {
+                    for (long k = 0; k < num_channels; ++k)
+                    {
+                        auto dest_slice = dest.device() + (s * num_channels + k) * plane_size;
+                        auto gi_slice = gradient_input.device() + (s * num_channels + k) * plane_size;
+                        auto grad_slice = grad.device() + (s * num_channels + k) * plane_size;
+                        auto a_dest_slice = alias_tensor(dest.nr(), dest.nc())(dest, (s * num_channels + k) * plane_size);
+                        auto a_gi_slice = alias_tensor(gradient_input.nr(), gradient_input.nc())(gradient_input, (s * num_channels + k) * plane_size);
+                        auto a_grad_slice = alias_tensor(grad.nr(), grad.nc())(grad, (s * num_channels + k) * plane_size);
+
+                        CHECK_CUDNN(cudnnSoftmaxBackward(context(),
+                                                         CUDNN_SOFTMAX_ACCURATE,
+                                                         CUDNN_SOFTMAX_MODE_CHANNEL,
+                                                         &alpha,
+                                                         descriptor(a_dest_slice),
+                                                         dest_slice,
+                                                         descriptor(a_gi_slice),
+                                                         gi_slice,
+                                                         &beta,
+                                                         descriptor(a_grad_slice),
+                                                         grad_slice));
+                    }
+                }
+            }
         }

 // ------------------------------------------------------------------------------------
diff --git a/dlib/cuda/cudnn_dlibapi.h b/dlib/cuda/cudnn_dlibapi.h
index 7b040a00c2..051ae2314a 100644
--- a/dlib/cuda/cudnn_dlibapi.h
+++ b/dlib/cuda/cudnn_dlibapi.h
@@ -352,13 +352,15 @@ namespace dlib
         void softmax (
             tensor& dest,
-            const tensor& src
+            const tensor& src,
+            size_t mode = 0
         );

         void softmax_gradient (
             tensor& grad,
             const tensor& dest,
-            const tensor& gradient_input
+            const tensor& gradient_input,
+            size_t mode = 0
         );

 // ------------------------------------------------------------------------------------
diff --git a/dlib/cuda/tensor_tools.cpp b/dlib/cuda/tensor_tools.cpp
index 90a09a2884..bdf46cb0b3 100644
--- a/dlib/cuda/tensor_tools.cpp
+++ b/dlib/cuda/tensor_tools.cpp
@@ -820,28 +820,30 @@ namespace dlib { namespace tt
 // ----------------------------------------------------------------------------------------
 // ----------------------------------------------------------------------------------------

-    void softmax (
+    void softmax(
         tensor& dest,
-        const tensor& src
+        const tensor& src,
+        size_t s_mode
     )
     {
 #ifdef DLIB_USE_CUDA
-        cuda::softmax(dest,src);
+        cuda::softmax(dest, src, s_mode);
 #else
-        cpu::softmax(dest,src);
+        cpu::softmax(dest, src, s_mode);
 #endif
     }

-    void softmax_gradient (
+    void softmax_gradient(
         tensor& grad,
         const tensor& dest,
-        const tensor& gradient_input
+        const tensor& gradient_input,
+        size_t s_mode
     )
     {
 #ifdef DLIB_USE_CUDA
-        cuda::softmax_gradient(grad, dest, gradient_input);
+        cuda::softmax_gradient(grad, dest, gradient_input, s_mode);
 #else
-        cpu::softmax_gradient(grad, dest, gradient_input);
+        cpu::softmax_gradient(grad, dest, gradient_input, s_mode);
 #endif
     }

diff --git a/dlib/cuda/tensor_tools.h b/dlib/cuda/tensor_tools.h
index 8ea593a429..426240f2e0 100644
--- a/dlib/cuda/tensor_tools.h
+++ b/dlib/cuda/tensor_tools.h
@@ -1388,42 +1388,52 @@ namespace dlib { namespace tt
     void softmax (
         tensor& dest,
-        const tensor& src
+        const tensor& src,
+        size_t mode = 0
     );
     /*!
        requires
            - have_same_dimensions(dest, src) == true
+            - mode == CHANNEL_WISE || mode == PLANE_WISE
        ensures
            - Note that the softmax function is a vector valued function:
                s(x) == exp(x)/sum(exp(x))
-            - Computes the softmax function on src and writes the results to dest.  The
-              softmax is computed per spatial location across the different channels at
-              each location.  That is, softmax() outputs a new tensor, #dest, where each of
-              the spatial locations in dest (i.e. image idx, row idx, and column idx)
-              contains the output of s() evaluated over the channel values at each
-              location.
+            - Computes the softmax function on src and writes the results to dest.
+            - If mode == CHANNEL_WISE:
+              The softmax is computed per spatial location across the different channels at
+              each location.  That is, softmax() outputs a new tensor, #dest, where each of
+              the spatial locations in dest (i.e. image idx, row idx, and column idx)
+              contains the output of s() evaluated over the channel values at each location.
+            - If mode == PLANE_WISE:
+              The softmax is computed across entire planes (nr x nc) of the input tensor.
+              This is useful for operations in Large Language Models (LLMs) and other
+              applications requiring 2D tensor processing.
            - This function supports in-place operation, i.e. having
-              is_same_object(dest, src)==true
+              is_same_object(dest, src)==true
    !*/

     void softmax_gradient (
         tensor& grad,
         const tensor& dest,
-        const tensor& gradient_input
+        const tensor& gradient_input,
+        size_t mode = 0
     );
     /*!
        requires
            - have_same_dimensions(dest,gradient_input) == true
            - have_same_dimensions(dest,grad) == true
-        ensures
-            - We interpret dest as the output of softmax(dest,SRC) for some SRC tensor.
-              Then let f(SRC) == dot(gradient_input,dest).  Then this function computes the
-              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
-              is_same_object(grad,gradient_input)==true then the output is assigned to
-              grad, replacing its previous contents.  Otherwise the output is added to
-              grad.
+            - mode == CHANNEL_WISE || mode == PLANE_WISE
+        ensures
+            - We interpret dest as the output of softmax(dest,SRC,mode) for some SRC tensor.
+              Then let f(SRC) == dot(gradient_input,dest).  Then this function computes the
+              gradient of f() with respect to SRC and stores it to grad.  Moreover, if
+              is_same_object(grad,gradient_input)==true then the output is assigned to
+              grad, replacing its previous contents.  Otherwise the output is added to grad.
+            - The gradient computation takes into account the specified mode:
+                - If mode == CHANNEL_WISE: The gradient is computed per spatial location across channels.
+                - If mode == PLANE_WISE: The gradient is computed across entire planes of the tensor.
            - This function supports in-place operation, i.e. having
-              is_same_object(grad, gradient_input)==true
+              is_same_object(grad, gradient_input)==true
    !*/

 // ----------------------------------------------------------------------------------------
diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index f34e7a8390..b5d78d2cb9 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -3985,31 +3985,30 @@ namespace dlib

 // ----------------------------------------------------------------------------------------

+    enum softmax_mode { CHANNEL_WISE = 0, PLANE_WISE = 1 };
+
+    template <softmax_mode s_mode_>
     class softmax_
     {
     public:
-        softmax_()
-        {
-        }
+        softmax_() {}

         template <typename SUBNET>
-        void setup (const SUBNET& /*sub*/)
-        {
-        }
+        void setup(const SUBNET& /*sub*/) {}

         void forward_inplace(const tensor& input, tensor& output)
         {
-            tt::softmax(output, input);
-        }
+            tt::softmax(output, input, s_mode_);
+        }

         void backward_inplace(
             const tensor& computed_output,
-            const tensor& gradient_input,
-            tensor& data_grad,
-            tensor&
+            const tensor& gradient_input,
+            tensor& data_grad,
+            tensor& /*params_grad*/
         )
         {
-            tt::softmax_gradient(data_grad, computed_output, gradient_input);
+            tt::softmax_gradient(data_grad, computed_output, gradient_input, s_mode_);
         }

         const tensor& get_layer_params() const { return params; }
@@ -4025,26 +4024,29 @@ namespace dlib
             std::string version;
             deserialize(version, in);
             if (version != "softmax_")
-                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
+                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::softmax_.");
         }

         friend std::ostream& operator<<(std::ostream& out, const softmax_& /*item*/)
         {
-            out << "softmax";
+            out << "softmax (mode=" << (s_mode_ == CHANNEL_WISE ? "channel_wise" : "plane_wise") << ")";
             return out;
         }

         friend void to_xml(const softmax_& /*item*/, std::ostream& out)
         {
-            out << "<softmax/>\n";
+            out << "<softmax mode='" << (s_mode_ == CHANNEL_WISE ? "channel_wise" : "plane_wise") << "'/>\n";
         }

     private:
-        resizable_tensor params;
+        resizable_tensor params; // unused
     };

     template <typename SUBNET>
-    using softmax = add_layer<softmax_, SUBNET>;
+    using softmax = add_layer<softmax_<CHANNEL_WISE>, SUBNET>;
+
+    template <typename SUBNET>
+    using softmaxm = add_layer<softmax_<PLANE_WISE>, SUBNET>;

 // ----------------------------------------------------------------------------------------
@@ -5082,10 +5084,10 @@ namespace dlib
     using tril_mask = add_layer, SUBNET>;

     template
-    using tril_diag = add_layer, SUBNET>;
-
+    using tril_diag = add_layer, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
-
+
 }

 #endif // DLIB_DNn_LAYERS_H_
\ No newline at end of file
diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h
index 0d951e7804..f92574a3a8 100644
--- a/dlib/dnn/layers_abstract.h
+++ b/dlib/dnn/layers_abstract.h
@@ -2953,44 +2953,69 @@ namespace dlib

 // ----------------------------------------------------------------------------------------

+    enum softmax_mode { CHANNEL_WISE = 0, PLANE_WISE = 1 };
+
+    template <softmax_mode s_mode_>
     class softmax_
     {
         /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
-                defined above.  In particular, it defines a softmax layer.  To be precise,
-                we define the softmax function s(x) as:
+                defined above.  It defines a softmax layer with two modes of operation:
+                channel-wise and plane-wise.
+
+                The softmax function s(x) is defined as:
                    s(x) == exp(x)/sum(exp(x))
-                where x is a vector.  Then this layer treats its input tensor as a
-                collection of multi-channel images and applies s() to each spatial location
-                in each image.  In each application, the tensor::k() channel elements at
-                each position are input to s() and then replaced by the outputs of s().
+                where x is a vector.
+
+                1. Channel-wise mode (s_mode_ == CHANNEL_WISE):
+                   This mode treats the input tensor as a collection of multi-channel images
+                   and applies s() to each spatial location in each image.  The tensor::k()
+                   channel elements at each position are input to s() and then replaced by
+                   the outputs of s().
+
+                2. Plane-wise mode (s_mode_ == PLANE_WISE):
+                   This mode applies the softmax function across entire planes (nr x nc) of
+                   the input tensor, useful for operations in Large Language Models (LLMs)
+                   and other applications requiring 2D tensor processing.

-                This means that, for example, if you collapsed each output image to a 1
-                channel image by adding the channels then you would end up with images
-                where each pixel value was 1.  This is because the sum of the outputs of
-                s() will always be equal to 1.
+                In both modes, the sum of the outputs of s() will always be equal to 1 for
+                each application of the function.
+
+                TEMPLATE PARAMETERS
+                - s_mode_: Determines the mode of operation (CHANNEL_WISE or PLANE_WISE)
        !*/

    public:
+        softmax_();

-        softmax_(
-        );
-
-        template <typename SUBNET> void setup (const SUBNET& sub);
+        template <typename SUBNET> void setup(const SUBNET& sub);
         void forward_inplace(const tensor& input, tensor& output);
-        void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
-        const tensor& get_layer_params() const;
-        tensor& get_layer_params();
+        void backward_inplace(
+            const tensor& computed_output,
+            const tensor& gradient_input,
+            tensor& data_grad,
+            tensor& params_grad
+        );
+        const tensor& get_layer_params() const;
+        tensor& get_layer_params();
        /*!
            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_
-            interface.  Note that this layer doesn't have any parameters, so the tensor
+            interface.  Note that this layer doesn't have any parameters, so the tensor
            returned by get_layer_params() is always empty.
        !*/
+
+        friend void serialize(const softmax_& item, std::ostream& out);
+        friend void deserialize(softmax_& item, std::istream& in);
+        friend std::ostream& operator<<(std::ostream& out, const softmax_& item);
+        friend void to_xml(const softmax_& item, std::ostream& out);
     };

     template <typename SUBNET>
-    using softmax = add_layer<softmax_, SUBNET>;
+    using softmax = add_layer<softmax_<CHANNEL_WISE>, SUBNET>;
+
+    template <typename SUBNET>
+    using softmaxm = add_layer<softmax_<PLANE_WISE>, SUBNET>;

 // ----------------------------------------------------------------------------------------
@@ -4088,8 +4113,7 @@ namespace dlib
     using tril_diag = add_layer, SUBNET>;

 // ----------------------------------------------------------------------------------------
-
+
 }
-#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
-
+#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
\ No newline at end of file
diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h
index 726f3b200e..f7ad26ae77 100644
--- a/dlib/dnn/visitors.h
+++ b/dlib/dnn/visitors.h
@@ -962,8 +962,8 @@ namespace dlib
             update(i);
         }

-        template <typename U, typename E>
-        void operator()(size_t i, const add_layer<softmax_, U, E>&)
+        template <softmax_mode S, typename U, typename E>
+        void operator()(size_t i, const add_layer<softmax_<S>, U, E>&)
         {
             start_node(i, "softmax");
             end_node();
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 6d3c6c94b4..713aae4485 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -128,8 +128,6 @@ namespace
         // fill like this as a test of the assignment operator.
         gradient_input = matrix_cast<float>(gaussian_randm(5,5*nr*nc, 2));

-
-
         auto grad_src = [&](long idx) {
             auto f = [&](float eps) {
                 const float old = src.host()[idx];
@@ -166,6 +164,85 @@ namespace
 #endif
     }

+    void test_softmaxm()
+    {
+        print_spinner();
+        using net_type = tag1<softmaxm<tag2<input<matrix<float>>>>>;
+        net_type net;
+
+        // Initialization
+        dlib::rand rnd(std::rand());
+        const long nr = 2, nc = 3;
+        const int n_samples = 3, k = 1;
+        std::vector<matrix<float>> x(n_samples);
+        matrix<float> xtmp(nr, nc);
+        for (int ii = 0; ii < n_samples; ++ii) {
+            for (int jj = 0; jj < nr; ++jj)
+                for (int kk = 0; kk < nc; ++kk) {
+                    float r = rnd.get_random_gaussian();
+                    if (r > 1 || r < -1) r = -std::numeric_limits<float>::infinity();
+                    xtmp(jj, kk) = r;
+                }
+            x[ii] = xtmp;
+        }
+
+        // Convert input matrix to tensor
+        resizable_tensor input_tensor;
+        net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
+        net.forward(input_tensor);
+
+        // Expected output tensor
+        resizable_tensor expected_output;
+        expected_output.copy_size(input_tensor);
+        for (int ii = 0; ii < n_samples; ++ii) {
+            for (int jj = 0; jj < nr; ++jj) {
+                matrix<float> m(1, nc);
+                bool all_neg_inf = true;
+                for (int kk = 0; kk < nc; ++kk) {
+                    m(0, kk) = input_tensor.host()[tensor_index(input_tensor, ii, 0, jj, kk)];
+                    if (m(0, kk) > -std::numeric_limits<float>::infinity()) all_neg_inf = false;
+                }
+
+                matrix<float> r(1, nc);
+                if (all_neg_inf)
+                    for (int kk = 0; kk < nc; ++kk) r(0, kk) = 0.0f;
+                else {
+                    // Stabilize the computation by subtracting the max value
+                    float max_val = max(m);
+                    matrix<float> exp_m = exp(m - max_val);
+                    float sum_exp = sum(exp_m) + std::numeric_limits<float>::epsilon();
+                    r = exp_m / sum_exp;
+                }
+                for (int kk = 0; kk < nc; ++kk)
+                    expected_output.host()[tensor_index(expected_output, ii, 0, jj, kk)] = r(0, kk);
+            }
+        }
+
+        // Compare output tensor with expected output
+        auto& net_output = layer<tag1>(net).get_output();
+        DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
+
+        // Compare CPU and CUDA utility functions
+        resizable_tensor output_tensor, cpu_grad, gradient_input;
+        output_tensor.copy_size(input_tensor);
+        cpu_grad.copy_size(input_tensor);
+        cpu_grad = 0;
+        gradient_input.copy_size(input_tensor);
+        randomize_parameters(gradient_input, nr + nc, rnd);
+        cpu::softmax(output_tensor, input_tensor, 1);
+        cpu::softmax_gradient(cpu_grad, output_tensor, gradient_input, 1);
+        DLIB_TEST(max(abs(mat(output_tensor) - mat(expected_output))) < 1e-5);
+#ifdef DLIB_USE_CUDA
+        resizable_tensor cuda_grad;
+        cuda_grad.copy_size(input_tensor);
+        cuda_grad = 0;
+        cuda::softmax(output_tensor, input_tensor, 1);
+        cuda::softmax_gradient(cuda_grad, output_tensor, gradient_input, 1);
+        DLIB_TEST(max(abs(mat(output_tensor) - mat(expected_output))) < 1e-5);
+        DLIB_TEST(max(abs(mat(cuda_grad) - mat(cpu_grad))) < 1e-5);
+#endif
+    }
+
     void test_softmax_all()
     {
         using namespace dlib::tt;
@@ -2390,10 +2467,16 @@ void test_embeddings()
         }
         {
             print_spinner();
-            softmax_ l;
+            softmax_<CHANNEL_WISE> l;
             auto res = test_layer(l);
             DLIB_TEST_MSG(res, res);
         }
+        {
+            print_spinner();
+            softmax_<PLANE_WISE> l;
+            auto res = test_layer(l);
+            DLIB_TEST_MSG(res, res);
+        }
         {
             print_spinner();
             softmax_all_ l;
@@ -2417,7 +2500,7 @@ void test_embeddings()
             embeddings_<7, 12> l;
             auto res = test_layer(l);
             DLIB_TEST_MSG(res, res);
-        }
+        }
     }

 // ----------------------------------------------------------------------------------------
@@ -4655,6 +4738,7 @@ void test_embeddings()
         test_avg_pool(4,5,40,50,0,1);
         test_tanh();
         test_softmax();
+        test_softmaxm();
        test_softmax_all();
        test_sigmoid();
        test_mish();
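
A minimal usage sketch (not part of the patch), assuming the enum values introduced above (CHANNEL_WISE == 0, PLANE_WISE == 1) and the new optional mode parameter on tt::softmax; the softmaxm layer alias wraps the same plane-wise call inside a network.

// Sketch only: drives tt::softmax directly with the new PLANE_WISE mode.
#include <dlib/dnn.h>

int main()
{
    using namespace dlib;

    // One sample, one channel, a 2x3 plane.
    resizable_tensor src(1, 1, 2, 3), dest;
    tt::tensor_rand rnd;
    rnd.fill_gaussian(src);
    dest.copy_size(src);

    // Default behaviour: per-location softmax across channels (CHANNEL_WISE == 0).
    tt::softmax(dest, src);

    // New behaviour added by this patch: row-wise softmax over each nr x nc plane
    // (PLANE_WISE == 1), so every row of dest sums to 1.
    tt::softmax(dest, src, PLANE_WISE);

    return 0;
}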