
Add tril_ layer for lower triangular matrix operations #3018


Merged: 11 commits, Sep 30, 2024
126 changes: 126 additions & 0 deletions dlib/dnn/layers.h
@@ -4696,6 +4696,132 @@ namespace dlib

    template <typename SUBNET> using transpose = add_layer<transpose_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    struct neg_infinity_tag {};
    struct zero_tag {};

    template<typename T>
    struct is_special_value : std::false_type {};
    template<>
    struct is_special_value<neg_infinity_tag> : std::true_type {};
    template<>
    struct is_special_value<zero_tag> : std::true_type {};

    template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
    class tril_
    {
    public:
        tril_() : diag(diag_), diag_value(compute_diag_value()) {}

        template <typename SUBNET>
        void setup(const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto& prev = sub.get_output();
            output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());

            check_mask(prev);
            tt::multiply(false, output, prev, binary_mask);
            if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto& prev_grad = sub.get_gradient_input();
            tt::multiply(true, prev_grad, gradient_input, binary_mask);
        }

        inline dpoint map_input_to_output(const dpoint& p) const { return p; }
        inline dpoint map_output_to_input(const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const tril_& item, std::ostream& out)
        {
            serialize("tril_", out);
            serialize(item.diag, out);
            serialize(item.diag_value, out);
        }
        friend void deserialize(tril_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "tril_")
                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
            deserialize(item.diag, in);
            deserialize(item.diag_value, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const tril_& item)
        {
            out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
            return out;
        }
        friend void to_xml(const tril_& item, std::ostream& out)
        {
            out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
        }

    private:
        float compute_diag_value() const {
            if (std::is_same<tag_, neg_infinity_tag>::value)
                return -std::numeric_limits<float>::infinity();
            else if (std::is_same<tag_, zero_tag>::value)
                return 0.0f;
            else
                return static_cast<float>(num_) / static_cast<float>(den_);
        }

        void check_mask(const tensor& t)
        {
            if (!have_same_dimensions(binary_mask, t)) {
                binary_mask.copy_size(t);
                binary_mask = 1;
                if (diag_value != 0.0f) {
                    output_mask.copy_size(t);
                    output_mask = 0;
                }
                // Iterate over the input's dimensions (not output_mask's, since
                // output_mask is only allocated when diag_value != 0) so the
                // binary mask is built correctly in every case.
                for (long s = 0; s < t.num_samples(); ++s)
                {
                    for (long k = 0; k < t.k(); ++k)
                    {
                        for (long r = 0; r < t.nr(); ++r)
                        {
                            for (long c = std::max(r + diag + 1, 0L); c < t.nc(); ++c)
                            {
                                if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
                                binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
                            }
                        }
                    }
                }
            }
        }

        resizable_tensor params; // unused
        resizable_tensor binary_mask, output_mask;
        long diag;
        float diag_value;
    };

    template <typename SUBNET>
    using tril = add_layer<tril_<0, zero_tag>, SUBNET>;

    template <typename SUBNET>
    using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;

    template <long diag, long num, long den, typename SUBNET>
    using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;

    // ----------------------------------------------------------------------------------------

}
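To make the new aliases concrete, here is a minimal standalone sketch (illustration only, not part of this diff) that pushes a single 4x4 matrix of ones through the tril alias. With zero_tag, everything strictly above the main diagonal comes back as 0:

#include <dlib/dnn.h>
#include <iostream>

using namespace dlib;

int main()
{
    // tril (zero_tag): zero out everything above the main diagonal.
    using net_type = tril<input<matrix<float>>>;
    net_type net;

    matrix<float> m(4, 4);
    m = 1;  // all ones, so masked entries are easy to spot

    resizable_tensor x;
    net.to_tensor(&m, &m + 1, x);
    net.forward(x);

    // Print the 4x4 output plane of the first sample/channel:
    // ones on and below the diagonal, zeros above.
    std::cout << image_plane(net.get_output(), 0, 0) << std::endl;
    return 0;
}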
156 changes: 156 additions & 0 deletions dlib/dnn/layers_abstract.h
@@ -3711,6 +3711,162 @@ namespace dlib
    template <typename SUBNET>
    using transpose = add_layer<transpose_, SUBNET>;

    // ----------------------------------------------------------------------------------------

    struct neg_infinity_tag {};
    struct zero_tag {};

    template<typename T>
    struct is_special_value : std::false_type {};
    template<>
    struct is_special_value<neg_infinity_tag> : std::true_type {};
    template<>
    struct is_special_value<zero_tag> : std::true_type {};

    template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
    class tril_
    {
        /*!
            TEMPLATE PARAMETERS
                - diag_: A long integer specifying the diagonal offset.
                - tag_: A type tag selecting a special mask value, or void for a numeric value.
                - num_: Numerator of the numeric mask value (default 0; used only if tag_ is void).
                - den_: Denominator of the numeric mask value (default 1; used only if tag_ is void).

            REQUIREMENTS
                - diag_ must be an integer.
                - tag_ must be either neg_infinity_tag, zero_tag, or void.
                - If tag_ is void, num_ and den_ are used to compute the mask value.
                - If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.

            WHAT THIS OBJECT REPRESENTS
                This object implements a layer in a deep neural network that applies a lower
                triangular mask to its input tensor.  The mask is defined such that all
                elements above the specified diagonal are set to a given value.  The diagonal
                offset and the mask value are determined by the template parameters.

            DIAGONAL VALUE DETERMINATION
                - If tag_ is neg_infinity_tag: masked elements are set to negative infinity.
                - If tag_ is zero_tag: masked elements are set to zero.
                - If tag_ is void: masked elements are set to num_ / den_ as a float.

            DIAGONAL OFFSET
                The diag_ parameter determines the diagonal above which elements are masked.
                Concretely, the element at row r and column c is masked whenever c > r + diag_:
                - diag_ = 0: main diagonal
                - diag_ > 0: diag_ steps above the main diagonal
                - diag_ < 0: |diag_| steps below the main diagonal

            EXAMPLE USAGE
                // Create a layer that masks all elements above the main diagonal with -inf
                tril_<0, neg_infinity_tag> layer1;

                // Create a layer that masks all elements above the main diagonal with 0
                tril_<0, zero_tag> layer2;

                // Create a layer that masks all elements above the main diagonal with 0.5
                tril_<0, void, 1, 2> layer3;

                // Create a layer that masks all elements more than 5 positions above the
                // main diagonal with -inf
                tril_<5, neg_infinity_tag> layer4;

                // Create a layer that masks everything except the elements at least
                // 3 positions below the main diagonal with 0.25
                tril_<-3, void, 1, 4> layer5;

            SERIALIZATION SUPPORT
                This object supports serialization and deserialization via the serialize()
                and deserialize() functions.
        !*/

    public:
        tril_() = default;
        /*!
            ensures
                - This object is properly initialized.
        !*/

        template <typename SUBNET>
        void setup(const SUBNET& sub);
        /*!
            requires
                - SUBNET is a valid network layer type.
            ensures
                - Initializes the mask based on the dimensions of the input tensor from sub.
        !*/

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output);
        /*!
            requires
                - SUBNET is a valid network layer type.
            ensures
                - Applies the lower triangular mask to the input tensor from sub and
                  stores the result in output.
        !*/

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
        /*!
            requires
                - SUBNET is a valid network layer type.
            ensures
                - Computes the gradient of the loss with respect to the input tensor
                  and stores it in sub.
        !*/

        inline dpoint map_input_to_output(const dpoint& p) const;
        /*!
            ensures
                - Maps a point from the input tensor to the corresponding point in
                  the output tensor.
        !*/

        inline dpoint map_output_to_input(const dpoint& p) const;
        /*!
            ensures
                - Maps a point from the output tensor to the corresponding point in
                  the input tensor.
        !*/

        const tensor& get_layer_params() const;
        /*!
            ensures
                - Returns the parameters of this layer.
        !*/

        tensor& get_layer_params();
        /*!
            ensures
                - Returns the parameters of this layer.
        !*/

        friend void serialize(const tril_& item, std::ostream& out);
        /*!
            ensures
                - Serializes the state of this object to the given output stream.
        !*/

        friend void deserialize(tril_& item, std::istream& in);
        /*!
            ensures
                - Deserializes the state of this object from the given input stream.
        !*/

        friend std::ostream& operator<<(std::ostream& out, const tril_& item);
        /*!
            ensures
                - Prints a human-readable representation of this object to the given
                  output stream.
        !*/

        friend void to_xml(const tril_& item, std::ostream& out);
        /*!
            ensures
                - Serializes the state of this object to XML format and writes it to
                  the given output stream.
        !*/
    };

    template <typename SUBNET>
    using tril = add_layer<tril_<0, zero_tag>, SUBNET>;

    template <typename SUBNET>
    using tril_mask = add_layer<tril_<0, neg_infinity_tag>, SUBNET>;

    template <long diag, long num, long den, typename SUBNET>
    using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;

    // ----------------------------------------------------------------------------------------

}
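The DIAGONAL OFFSET rule above amounts to: the element at row r, column c is kept iff c <= r + diag_. A small standalone sketch (plain C++, no dlib required, written for this write-up) that prints the keep/mask pattern for a chosen offset:

#include <cstdio>

int main()
{
    const long nr = 4, nc = 4;
    const long diag = 1;  // try 0, 1, or -1 to see the diagonal shift

    for (long r = 0; r < nr; ++r)
    {
        for (long c = 0; c < nc; ++c)
            std::printf("%c ", c <= r + diag ? '1' : '.');  // '1' = kept, '.' = masked
        std::printf("\n");
    }
    return 0;
}

For diag = 1 this prints a lower triangle widened by one superdiagonal, matching the layer4 example above.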
16 changes: 16 additions & 0 deletions dlib/dnn/visitors.h
@@ -1029,6 +1029,22 @@ namespace dlib
            update(i);
        }

        template <long diag, typename tag, long num, long den, typename U, typename E>
        void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
        {
            start_node(i, "tril");
            out << " | {diag|{" << diag << "}}";
            out << " | {diag_value|{";

            if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
            else if (std::is_same<tag, zero_tag>::value) out << "0";
            else out << static_cast<float>(num) / static_cast<float>(den);

            out << "}}";
            end_node();
            update(i);
        }

        template <typename T, typename U, typename E>
        void operator()(size_t i, const add_layer<T, U, E>&)
        {
48 changes: 48 additions & 0 deletions dlib/test/dnn.cpp
@@ -2023,6 +2023,12 @@ namespace
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            tril_<-5, void, 1, 2> l;
            auto res = test_layer(l);
            DLIB_TEST_MSG(res, res);
        }
        {
            print_spinner();
            extract_<0,2,2,2> l;
@@ -4447,6 +4453,47 @@
        }
    }

    // ----------------------------------------------------------------------------------------

    void test_tril()
    {
        print_spinner();
        using net_type = tag1<tril_mask<tag2<input<matrix<float>>>>>;
        net_type net;

        // Build random input matrices
        dlib::rand rnd;
        const int nr = 2, nc = 3;
        constexpr int n_samples = 3, k = 1;
        std::vector<matrix<float>> x(n_samples);
        matrix<float> xtmp(nr, nc);
        for (int ii = 0; ii < n_samples; ++ii) {
            for (int jj = 0; jj < nr; ++jj)
                for (int kk = 0; kk < nc; ++kk)
                    xtmp(jj, kk) = rnd.get_random_gaussian();
            x[ii] = xtmp;
        }

        // Convert the input matrices to a tensor and run the network
        resizable_tensor input_tensor;
        net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
        net.forward(input_tensor);

        // Expected output tensor: every element above the main diagonal is -inf.
        // For a 2x3 input those are the entries (0,1), (0,2), and (1,2).
        resizable_tensor expected_output;
        expected_output.copy_size(input_tensor);
        tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
        for (int ii = 0; ii < n_samples; ++ii) {
            expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
            expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
            expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
        }

        // Compare the network output with the expected output
        auto& net_output = layer<tag1>(net).get_output();
        DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
    }
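The same harness carries over to the numeric-value variant. A hedged sketch (hypothetical, not part of this PR's tests, reusing the file's tensor_index and DLIB_TEST helpers) of the analogous check for tril_diag, where masked entries should come back as exactly num/den = 0.5f:

    // Hypothetical variant of test_tril for a numeric mask value (1/2 = 0.5f).
    void test_tril_diag_sketch()
    {
        using net_type = tag1<tril_diag<0, 1, 2, tag2<input<matrix<float>>>>>;
        net_type net;

        matrix<float> m(2, 3);
        m = 1;  // all ones, so masked entries are easy to spot
        resizable_tensor x;
        net.to_tensor(&m, &m + 1, x);
        net.forward(x);

        // The entries above the main diagonal, (0,1), (0,2), and (1,2),
        // should now hold 0.5f; all other entries are unchanged.
        const auto& out = layer<tag1>(net).get_output();
        DLIB_TEST(out.host()[tensor_index(out, 0, 0, 0, 1)] == 0.5f);
        DLIB_TEST(out.host()[tensor_index(out, 0, 0, 0, 0)] == 1.0f);
    }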

    // ----------------------------------------------------------------------------------------

    class dnn_tester : public tester
Expand Down Expand Up @@ -4527,6 +4574,7 @@ namespace
            test_layer_normalize();
            test_rms_normalize();
            test_transpose();
            test_tril();
            test_basic_tensor_ops();
            test_layers();
            test_visit_functions();