Add tril_ layer for lower triangular matrix operations #3018

Merged: 11 commits, Sep 30, 2024

Conversation

Cydral
Contributor

@Cydral Cydral commented Sep 23, 2024

This PR introduces a new tril_ layer to dlib, which implements lower triangular matrix operations similar to PyTorch's torch.tril() function. The layer allows for flexible lower triangular matrix masking with customizable diagonal offset and diagonal value.

Key features:

  • Implements lower triangular masking for tensors
  • Supports custom diagonal offset and diagonal value
  • Offers three convenient alias templates: tril, tril_mask, and tril_diag

This addition lets dlib networks express architectures that depend on lower triangular matrix operations. The layer is particularly useful in attention mechanisms and in any other scenario where lower triangular masking is required.
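To make the masking rule concrete, here is a minimal standalone sketch in plain C++. It does not use dlib: `tril_sketch` and its parameters are hypothetical names, and it assumes the "diagonal value" is the value written to every entry above the chosen diagonal (0 reproduces the default behaviour of torch.tril()).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, NOT dlib's API: lower triangular masking of a
// row-major rows x cols matrix. Entries with (c - r) > diag lie above the
// chosen diagonal and are overwritten with fill_value; the rest are kept.
inline std::vector<float> tril_sketch(const std::vector<float>& m,
                                      long rows, long cols,
                                      long diag = 0, float fill_value = 0.0f)
{
    assert(rows >= 0 && cols >= 0);
    assert(m.size() == static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols));

    std::vector<float> out(m);  // copy, then overwrite the masked entries
    for (long r = 0; r < rows; ++r)
        for (long c = 0; c < cols; ++c)
            if (c - r > diag)
                out[static_cast<std::size_t>(r * cols + c)] = fill_value;
    return out;
}
```

With diag = 0 the main diagonal is kept; a positive offset keeps extra diagonals above it, and a negative offset masks the main diagonal as well. For an attention mask, fill_value would typically be negative infinity rather than 0.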

@Cydral
Contributor Author

Cydral commented Sep 23, 2024

C++17 does not allow a float as a non-type template parameter (that only became possible in C++20). Therefore, I used an intermediate structure to pass the float to the tril_ layer. However, the GCC compilation still fails. Is there a "trick" in dlib to do this?

@arrufat
Contributor

arrufat commented Sep 24, 2024

I didn't check in detail what you're doing, but the one time I needed floats in a template parameter, I just used two integers (numerator and denominator). Similar to what std::ratio does.

In your case, something like:

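template<long num, long den>
class tril_
{
    static_assert(den != 0, "denominator must be non-zero");
public:
    // Rebuild the float value from the two integer template arguments.
    tril_() : diag(static_cast<float>(num) / static_cast<float>(den)) {}
private:
    float diag;  // equals num / den
};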

@Cydral
Contributor Author

Cydral commented Sep 24, 2024

Thank you for your feedback, @arrufat. I appreciate your suggestion, and I've actually explored various approaches during this refactoring process. C++17 compatibility would indeed make the code more efficient to write, but given our current constraints, I've had to find a C++14 compatible solution.

In the end, I've implemented a mechanism very similar to what you've described. The refactored tril_ class now uses a combination of tags (for specific values like negative infinity) and a numerator/denominator approach for other numeric values.
The static compilation tests now pass, and my own unit tests give satisfactory results, although I think passing a float directly would have been preferable for precision.
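For readers curious how tags and a numerator/denominator pair can coexist in C++14, the following is a rough sketch of one possible arrangement. It is not the code merged in this PR: every name here (`neg_infinity_tag`, `zero_tag`, `ratio_tag`, `fill_value`, `tril_params`) is hypothetical and chosen only for illustration.

```cpp
#include <cassert>
#include <limits>

// Hypothetical tag types, NOT dlib's definitions.
struct neg_infinity_tag {};  // fill with -infinity (not expressible as num/den)
struct zero_tag {};          // fill with 0
struct ratio_tag {};         // fill with num / den

// Maps (tag, num, den) to a float at run time; only specializations exist.
template <typename tag, long num, long den>
struct fill_value;

template <long num, long den>
struct fill_value<neg_infinity_tag, num, den>
{
    static float get() { return -std::numeric_limits<float>::infinity(); }
};

template <long num, long den>
struct fill_value<zero_tag, num, den>
{
    static float get() { return 0.0f; }
};

template <long num, long den>
struct fill_value<ratio_tag, num, den>
{
    static_assert(den != 0, "denominator must be non-zero");
    static float get() { return static_cast<float>(num) / static_cast<float>(den); }
};

// Stand-in for the layer's compile-time configuration: a diagonal offset
// plus a fill value selected through the tag.
template <long diag, typename tag, long num = 0, long den = 1>
class tril_params
{
public:
    tril_params() : fill(fill_value<tag, num, den>::get()) {}
    long diagonal() const { return diag; }
    float value() const { return fill; }
private:
    float fill;
};

// Usage: a -infinity mask on the main diagonal, and a 1/2 fill one above it.
inline void tril_params_usage()
{
    tril_params<0, neg_infinity_tag> mask;
    assert(mask.value() == -std::numeric_limits<float>::infinity());

    tril_params<1, ratio_tag, 1, 2> half;
    assert(half.diagonal() == 1 && half.value() == 0.5f);
}
```

The tags cover values such as negative infinity that no integer ratio can express, while the ratio covers ordinary numeric values. The precision concern raised above still applies: only values representable as a ratio of two long integers can be requested this way.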

@davisking
Owner

Looks good, thanks for the PR :)

@davisking davisking merged commit 4e53f83 into davisking:master Sep 30, 2024
10 checks passed