Add positional_encodings_ layer to Dlib #3019

Closed · wants to merge 16 commits
21 changes: 10 additions & 11 deletions dlib/dnn/layers.h
@@ -4704,18 +4704,19 @@ namespace dlib
            sequence_dim(sequence_dim_), embedding_dim(embedding_dim_)
        {
        }

        positional_encodings_(const positional_encodings_& item) :
            pe(item.pe), sequence_dim(item.sequence_dim), embedding_dim(item.embedding_dim)
        {
        }

        positional_encodings_& operator= (const positional_encodings_& item)
        {
            if (this == &item) return *this;
            pe = item.pe;
            sequence_dim = item.sequence_dim;
            embedding_dim = item.embedding_dim;
            return *this;
        }

        template <typename SUBNET>
        void setup(const SUBNET& sub)
        {
@@ -4727,7 +4728,7 @@ namespace dlib
            const unsigned long nk = prev.k();
            const float n = 10000.0f;

            pe.set_size(ns, nk, sequence_dim, embedding_dim);
            for (unsigned long s = 0; s < ns; ++s)
            {
                for (unsigned long k = 0; k < nk; ++k)
                {
@@ -4744,13 +4745,12 @@
                }
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            const auto& prev_output = sub.get_output();
            if (!have_same_dimensions(prev_output, pe)) setup(sub);
            output.set_size(prev_output.num_samples(), prev_output.k(), sequence_dim, embedding_dim);
            tt::add(output, prev_output, pe);
        }
@@ -4785,8 +4785,7 @@ namespace dlib
            out << "positional_encodings";
            return out;
        }

        friend void to_xml(const positional_encodings_& /*item*/, std::ostream& out) {
            out << "<positional_encodings />\n";
        }

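For reference, a minimal standalone sketch of the table that setup() fills in, following the convention of the unit test at the bottom of this diff: the exponent uses the raw column index c, so each sin/cos pair gets a slightly different frequency than the shared-exponent pairing of the canonical Vaswani et al. formulation. The helper name is illustrative, not part of dlib's API.

```cpp
#include <cmath>
#include <vector>

// Illustrative helper (not dlib API): builds a sequence_dim x embedding_dim
// table where even columns hold sin(theta) and odd columns hold cos(theta),
// with theta = pos / 10000^(c / embedding_dim).
std::vector<std::vector<float>> make_positional_encodings(
    unsigned long sequence_dim,
    unsigned long embedding_dim
)
{
    const float n = 10000.0f;
    std::vector<std::vector<float>> pe(sequence_dim, std::vector<float>(embedding_dim));
    for (unsigned long pos = 0; pos < sequence_dim; ++pos)
    {
        for (unsigned long c = 0; c < embedding_dim; ++c)
        {
            const float theta = pos / std::pow(n, static_cast<float>(c) / embedding_dim);
            pe[pos][c] = (c % 2 == 0) ? std::sin(theta) : std::cos(theta);
        }
    }
    return pe;
}
```

Since sin(0) = 0 and cos(0) = 1, row 0 of this table alternates [0, 1, 0, 1, ...], which is a handy sanity check when the layer is fed a zero input, as in the test.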
149 changes: 147 additions & 2 deletions dlib/dnn/layers_abstract.h
@@ -4089,7 +4089,152 @@

// ----------------------------------------------------------------------------------------

    class positional_encodings_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
                It defines a positional encoding layer that adds position information to
                the input tensor.  This is particularly useful in transformer
                architectures, where the order of the sequence matters.

                The dimensions of the tensors output by this layer are the same as the
                input tensor dimensions.

                This implementation is based on the positional encoding described in:
                Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L.,
                Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all
                you need. In Advances in Neural Information Processing Systems
                (pp. 5998-6008).

                The encoding uses sine and cosine functions of different frequencies:
                    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
                    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
                where pos is the position and i is the dimension index.
        !*/

    public:

        positional_encodings_(
            unsigned long sequence_dim_ = 1,
            unsigned long embedding_dim_ = 1
        );
        /*!
            ensures
                - #sequence_dim == sequence_dim_
                - #embedding_dim == embedding_dim_
        !*/

        positional_encodings_(
            const positional_encodings_& item
        );
        /*!
            ensures
                - EXAMPLE_COMPUTATIONAL_LAYER_ objects are copy constructible
        !*/
        positional_encodings_& operator=(
            const positional_encodings_& item
        );
        /*!
            ensures
                - EXAMPLE_COMPUTATIONAL_LAYER_ objects are assignable
        !*/

        template <typename SUBNET>
        void setup (
            const SUBNET& sub
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
            ensures
                - performs any necessary setup for the layer, including the computation
                  of the positional encodings based on the dimensions of the input.
        !*/

        template <typename SUBNET>
        void forward(
            const SUBNET& sub,
            resizable_tensor& output
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
            ensures
                - adds the positional encodings to the output of the subnetwork and
                  stores the result into #output.
        !*/

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input,
            SUBNET& sub,
            tensor& params_grad
        );
        /*!
            requires
                - SUBNET implements the SUBNET interface defined at the top of this file.
                - setup() has been called.
                - params_grad is unused by this layer since it has no learnable parameters.
            ensures
                - computes the gradient of the layer with respect to the input, which is
                  simply the input gradient itself, since the positional encodings are
                  constant.
        !*/

        const tensor& get_layer_params(
        ) const;
        /*!
            ensures
                - returns the parameters that define the behavior of forward().
                  Note that this layer has no learnable parameters, so an empty
                  tensor is returned.
        !*/

        tensor& get_layer_params(
        );
        /*!
            ensures
                - returns the parameters that define the behavior of forward().
                  Note that this layer has no learnable parameters, so an empty
                  tensor is returned.
        !*/

        const tensor& get_positional_encodings(
        ) const;
        /*!
            ensures
                - returns the computed positional encodings.
        !*/

        tensor& get_positional_encodings(
        );
        /*!
            ensures
                - returns the computed positional encodings.
        !*/

        friend void serialize(const positional_encodings_& item, std::ostream& out);
        friend void deserialize(positional_encodings_& item, std::istream& in);
        /*!
            provides serialization support
        !*/

        friend std::ostream& operator<<(std::ostream& out, const positional_encodings_& item);
        /*!
            prints a string describing this layer.
        !*/

        friend void to_xml(const positional_encodings_& item, std::ostream& out);
        /*!
            This function is optional, but required if you want to print your networks
            with net_to_xml().  It prints a layer as XML.
        !*/
    };

    template <typename SUBNET>
    using positional_encodings = add_layer<positional_encodings_, SUBNET>;

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
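As a usage sketch, the layer composes like any other dlib computational layer. The surrounding layers and dimensions below are illustrative assumptions, not something this PR prescribes:

```cpp
#include <dlib/dnn.h>

using namespace dlib;

// Hypothetical transformer-style front end: token ids pass through dlib's
// existing embeddings layer (vocabulary size 1000 and dimension 64 are
// arbitrary here), then the constant sinusoidal encodings are added
// element-wise. positional_encodings leaves the tensor dimensions unchanged.
using net_type = positional_encodings<
                 embeddings<1000, 64,
                 input<matrix<unsigned long, 0, 1>>>>;
```

Because get_layer_params() returns an empty tensor, the layer adds no trainable weights to the model, and forward() re-runs setup() whenever the input dimensions stop matching the cached encoding table.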
18 changes: 9 additions & 9 deletions dlib/dnn/visitors.h
@@ -1029,20 +1029,20 @@ namespace dlib
            update(i);
        }

        template <typename U, typename E>
        void operator()(size_t i, const add_layer<positional_encodings_, U, E>&)
        {
            start_node(i, "positional_encodings");
            end_node();
            update(i);
        }

        template <unsigned long ne, unsigned long ed, typename U, typename E>
        void operator()(size_t i, const add_layer<embeddings_<ne, ed>, U, E>& l)
        {
            start_node(i, "embeddings");
            out << " | {num_embeddings|{" << l.layer_details().get_num_embeddings() << "}}";
            out << " | {embedding_dim|{" << l.layer_details().get_embedding_dim() << "}}";
            end_node();
            update(i);
        }
@@ -1061,7 +1061,7 @@
out << "}}";
end_node();
update(i);
}
}

        template <typename T, typename U, typename E>
        void operator()(size_t i, const add_layer<T, U, E>&)
68 changes: 34 additions & 34 deletions dlib/test/dnn.cpp
@@ -781,39 +781,6 @@

// ----------------------------------------------------------------------------------------

    void test_embeddings()
    {
        print_spinner();
@@ -862,6 +829,39 @@ void test_embeddings()
        DLIB_TEST(acc > 0.9);
    }

// ----------------------------------------------------------------------------------------

    void test_positional_encodings()
    {
        print_spinner();
        using net_type = tag1<positional_encodings<input<matrix<float>>>>;
        net_type net;

        const unsigned long sequence_dim = 4;
        const unsigned long embedding_dim = 6;
        const unsigned long n_samples = 1, n_channels = 1;
        matrix<float> input_data(sequence_dim, embedding_dim);
        input_data = 0.0f;

        resizable_tensor input_tensor(n_samples, n_channels, sequence_dim, embedding_dim);
        std::vector<matrix<float>> x(n_samples);
        x[0] = input_data;
        net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
        net.forward(input_tensor);

        matrix<float> expected_output(sequence_dim, embedding_dim);
        const float n = 10000.0f;
        for (long r = 0; r < sequence_dim; ++r) {
            for (long c = 0; c < embedding_dim; ++c) {
                float theta = static_cast<float>(r) / std::pow(n, static_cast<float>(c) / embedding_dim);
                expected_output(r, c) = (c % 2 == 0) ? std::sin(theta) : std::cos(theta);
            }
        }

        auto& net_output = layer<tag1>(net).get_output();
        DLIB_TEST(max(abs(mat(net_output) - expected_output)) < 1e-5);
    }

// ----------------------------------------------------------------------------------------

    void test_basic_tensor_ops()
@@ -2111,7 +2111,7 @@ void test_embeddings()
            tril_<-5, void, 1, 2> l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            extract_<0,2,2,2> l;
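The test above pins down the forward values; a natural follow-up would be to run the layer through dlib's generic test_layer() harness as well, mirroring the tril_ and extract_ blocks in this file. A sketch, assuming the harness is happy with a layer that sizes its state in setup():

```cpp
{
    print_spinner();
    positional_encodings_ l(4, 6);
    auto res = test_layer(l);
    DLIB_TEST_MSG(res, res);
}
```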