Add transformer example with RoPE and MoE-like mechanisms #3078

Open · wants to merge 23 commits into master

Commits (23)
3e9b9f1
Implementation of linear_ layer for neural networks. This layer provi…
Cydral Apr 28, 2025
93ead3d
Minor change
Cydral May 2, 2025
bf1b805
Update dlib/dnn/layers.h
davisking May 3, 2025
49bfbc6
Merge branch 'davisking:master' into master
Cydral May 6, 2025
f234faa
Add reshape_to and flatten layers to Dlib's DNN module
Cydral May 6, 2025
26a2960
Missing update to "visitors.h"
Cydral May 22, 2025
c9a1ee4
format fixing for reshape_to
Cydral May 22, 2025
02e62d8
Update dlib/test/dnn.cpp
davisking May 23, 2025
394dee8
Merge branch 'davisking:master' into master
Cydral May 29, 2025
778bfc1
Vocabulary size fixed for learning, and function added for transforma…
Cydral May 29, 2025
03aafc2
Added a new example for learning a “complex” Transformer model.
Cydral May 29, 2025
22c2561
Added a new example for learning a “complex” Transformer model.
Cydral May 29, 2025
01cd0b2
Updated example for training a Transformer model.
Cydral May 29, 2025
6b63e55
fix for gcc/ffmpeg compilation
Cydral May 30, 2025
ad1f757
Fix a warning message for Ubuntu compilation.
Cydral May 30, 2025
c91c45a
Update for Linux environment.
Cydral May 30, 2025
6fcc0aa
Fix batch building
Cydral May 31, 2025
5a1773e
Slight improvement in model definition.
Cydral Jun 3, 2025
10d7c59
linear_ layer implementation improvement
Cydral Jun 7, 2025
d4bf94b
finalizing the example
Cydral Jun 7, 2025
a4dac0b
Fixing break condition in training method.
Cydral Jun 8, 2025
63454e3
Fixing declaration order of variables.
Cydral Jun 8, 2025
87ed70a
bpe_tokenizer improvements.
Cydral Jun 8, 2025
37 changes: 30 additions & 7 deletions dlib/dnn/layers.h
@@ -2329,24 +2329,46 @@ namespace dlib

template <
unsigned long num_outputs_,
linear_bias_mode bias_mode_
linear_bias_mode bias_mode_ = LINEAR_HAS_BIAS
>
class linear_
{
static_assert(num_outputs_ > 0, "The number of outputs from a linear_ layer must be > 0");

public:
linear_() :
explicit linear_() :
num_outputs(num_outputs_),
num_inputs(0),
num_inputs(0),
learning_rate_multiplier(1),
bias_mode(bias_mode_) {
}

linear_(const linear_& other) :
num_outputs(other.num_outputs),
num_inputs(other.num_inputs),
learning_rate_multiplier(other.learning_rate_multiplier),
bias_mode(other.bias_mode),
params(other.params),
weights(other.weights),
biases(other.biases) {
}

linear_& operator=(const linear_& other) {
if (this != &other) {
num_outputs = other.num_outputs;
num_inputs = other.num_inputs;
learning_rate_multiplier = other.learning_rate_multiplier;
bias_mode = other.bias_mode;
params = other.params;
weights = other.weights;
biases = other.biases;
}
return *this;
}

double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }

unsigned long get_num_inputs() const { return num_inputs; }

unsigned long get_num_outputs() const { return num_outputs; }
void set_num_outputs(long num)
{
@@ -2358,6 +2380,7 @@ namespace dlib
num_outputs = num;
}
}
unsigned long get_num_inputs() const { return num_inputs; }
linear_bias_mode get_bias_mode() const { return bias_mode; }

template <typename SUBNET>
@@ -2503,8 +2526,8 @@ namespace dlib
}

private:
unsigned long num_inputs;
unsigned long num_outputs;
unsigned long num_inputs;
double learning_rate_multiplier;
linear_bias_mode bias_mode;
resizable_tensor params;
@@ -2515,7 +2538,7 @@
unsigned long num_outputs,
typename SUBNET
>
using linear = add_layer<linear_<num_outputs, LINEAR_HAS_BIAS>, SUBNET>;
using linear = add_layer<linear_<num_outputs>, SUBNET>;

template <
unsigned long num_outputs,
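
The layers.h diff above makes LINEAR_HAS_BIAS the default bias mode of linear_, adds a copy constructor and copy assignment operator, and simplifies the linear alias accordingly. The following minimal sketch is not part of the PR; the network shape, layer sizes, and input type are illustrative assumptions, shown only to demonstrate how the updated API could be used.

#include <iostream>
#include <dlib/dnn.h>

using namespace dlib;

// Toy network using the updated alias: linear<32, SUBNET> now expands to
// linear_<32, LINEAR_HAS_BIAS> without spelling out the bias mode.
// The overall architecture here is illustrative only.
using toy_net = loss_multiclass_log<
                fc<10,
                relu<linear<32,
                input<matrix<float>>>>>>;

int main()
{
    // The copy constructor and assignment operator added in this PR let a
    // configured layer object be duplicated directly.
    linear_<32> a;
    a.set_learning_rate_multiplier(0.1);
    linear_<32> b(a);     // copy constructor
    linear_<32> c; c = a; // copy assignment

    std::cout << b.get_num_outputs() << " "
              << c.get_learning_rate_multiplier() << std::endl;

    toy_net net;  // compiles with the defaulted bias mode
    return 0;
}
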
166 changes: 98 additions & 68 deletions dlib/tokenizer/bpe_tokenizer.h
@@ -20,49 +20,32 @@

namespace dlib
{
constexpr size_t BPE_TOKENIZER_MAX_TOKEN_LENGTH = 8;
constexpr int BPE_TOKENIZER_BASE_VOCAB_SIZE = 256;

class bpe_tokenizer
{
public:
bpe_tokenizer() : vocab_size(BPE_TOKENIZER_BASE_VOCAB_SIZE)
bpe_tokenizer() : vocab_size(BASE_VOCAB_SIZE)
{
// Initialize the base vocabulary with single bytes
for (int i = 0; i < BPE_TOKENIZER_BASE_VOCAB_SIZE; ++i)
for (int i = 0; i < BASE_VOCAB_SIZE; ++i)
vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };

// Initialize special tokens with sequential IDs
special_tokens =
{
{"<text>", BPE_TOKENIZER_BASE_VOCAB_SIZE},
{"</text>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 1},
{"<url>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 2},
{"</url>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 3},
{"<image>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 4},
{"</image>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 5},
{"<video>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 6},
{"</video>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 7},
{"<audio>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 8},
{"</audio>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 9},
{"<file>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 10},
{"</file>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 11},
{"<code>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 12},
{"</code>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 13},
{"<summary>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 14},
{"</summary>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 15},
{"<think>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 16},
{"</think>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 17},
{"<start>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 18},
{"<end>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 19},
{"<user>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 20},
{"<bot>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 21},
{"<system>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 22},
{"<question>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 23},
{"<answer>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 24},
{"<search>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 25},
{"<unk>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 26},
{"<pad>", BPE_TOKENIZER_BASE_VOCAB_SIZE + 27}
special_tokens = {
{"<text>", BASE_VOCAB_SIZE}, {"</text>", BASE_VOCAB_SIZE + 1},
{"<url>", BASE_VOCAB_SIZE + 2}, {"</url>", BASE_VOCAB_SIZE + 3},
{"<image>", BASE_VOCAB_SIZE + 4}, {"</image>", BASE_VOCAB_SIZE + 5},
{"<video>", BASE_VOCAB_SIZE + 6}, {"</video>", BASE_VOCAB_SIZE + 7},
{"<audio>", BASE_VOCAB_SIZE + 8}, {"</audio>", BASE_VOCAB_SIZE + 9},
{"<file>", BASE_VOCAB_SIZE + 10}, {"</file>", BASE_VOCAB_SIZE + 11},
{"<code>", BASE_VOCAB_SIZE + 12}, {"</code>", BASE_VOCAB_SIZE + 13},
{"<summary>", BASE_VOCAB_SIZE + 14}, {"</summary>", BASE_VOCAB_SIZE + 15},
{"<think>", BASE_VOCAB_SIZE + 16}, {"</think>", BASE_VOCAB_SIZE + 17},
{"<start>", BASE_VOCAB_SIZE + 18}, {"<end>", BASE_VOCAB_SIZE + 19},
{"<user>", BASE_VOCAB_SIZE + 20}, {"<bot>", BASE_VOCAB_SIZE + 21},
{"<system>", BASE_VOCAB_SIZE + 22}, {"<question>", BASE_VOCAB_SIZE + 23},
{"<answer>", BASE_VOCAB_SIZE + 24}, {"<search>", BASE_VOCAB_SIZE + 25},
{"<unk>", BASE_VOCAB_SIZE + 26}, {"<pad>", BASE_VOCAB_SIZE + 27}
};

// Initialize the vector of special token IDs
@@ -73,57 +56,99 @@ namespace dlib
// Train the tokenizer on the given text
void train(const std::string& text, int vocab_size, bool verbose = false)
{
DLIB_CASSERT(vocab_size >= BPE_TOKENIZER_BASE_VOCAB_SIZE);
this->vocab_size = vocab_size;
int num_merges = vocab_size - BPE_TOKENIZER_BASE_VOCAB_SIZE;
int current_base = static_cast<int>(BASE_VOCAB_SIZE + special_tokens.size());
DLIB_CASSERT(vocab_size >= current_base);
int num_merges = vocab_size - current_base;
if (num_merges <= 0) return;

// Convert text to byte IDs
std::vector<int> ids;
ids.reserve(text.size());
for (char c : text) ids.push_back(static_cast<uint8_t>(c));

// Perform BPE merges
for (int i = 0; i < num_merges; ++i) {
int n_merges = 0;
for (; n_merges < num_merges; ++n_merges) {
auto stats = get_stats(ids);
if (stats.empty()) break;

// Find the most frequent pair that does not exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
// Find the most frequent pair that does not exceed MAX_TOKEN_LENGTH
auto pair = get_most_frequent_pair(stats);
if (pair.first == -1) break;

// Check if the resulting token would exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
// Check if the resulting token would exceed MAX_TOKEN_LENGTH
size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) {
if (new_token_length > MAX_TOKEN_LENGTH) {
if (verbose)
{
std::cout << "\r"
<< std::setw(100) << std::flush
<< "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
<< std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> new token length "
<< std::to_string(new_token_length) << " exceeds limit of " << std::to_string(BPE_TOKENIZER_MAX_TOKEN_LENGTH)
<< std::flush;
}
std::cout << "\r" << std::setw(100) << std::flush << "\r[skip] merge " << (n_merges + 1)
<< ": token too long: " << new_token_length << "/" << MAX_TOKEN_LENGTH << std::flush;
continue; // Skip this merge
}

int idx = (BPE_TOKENIZER_BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
ids = merge(ids, pair, idx);
merges[pair] = idx;
vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
vocab[idx].insert(vocab[idx].end(), vocab[pair.second].begin(), vocab[pair.second].end());
int new_id = current_base + n_merges;
merges[pair] = new_id;

std::vector<uint8_t>& new_token = vocab[new_id];
new_token.reserve(new_token_length);
new_token.insert(new_token.end(), vocab[pair.first].begin(), vocab[pair.first].end());
new_token.insert(new_token.end(), vocab[pair.second].begin(), vocab[pair.second].end());

ids = merge(ids, pair, new_id);

if (verbose)
{
std::cout << "\r"
<< std::setw(100) << std::flush
<< "\rmerge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
<< std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> " << std::to_string(idx)
<< " (" << bytes_to_string(vocab[idx]) << ") had "
<< std::to_string(stats[pair]) << " occurrences"
<< std::endl;
std::cout << "\r" << std::setw(100) << std::flush << "\r[merge] " << (n_merges + 1) << "/" << num_merges
<< ": (" << pair.first << "," << pair.second << ") -> " << new_id
<< " (" << bytes_to_string(vocab[new_id]) << ")" << std::endl;
}
this->vocab_size = current_base + n_merges;
}

// Encode the given text into subword tokens without paragraph splitting or special token wrapping
std::vector<int> encode_raw(const std::string& text) const
{
// Direct encoding without paragraph splitting or special tokens
std::vector<int> ids;
ids.reserve(text.size());

// Convert text to character IDs
for (char c : text) ids.push_back(static_cast<uint8_t>(c));

// Apply BPE merges
auto stats = get_stats(ids);
std::priority_queue<std::pair<int, std::pair<int, int>>> pq;
for (const auto& stat : stats) {
const std::pair<int, int>& pair = stat.first;
if (merges.count(pair)) pq.push({ merges.at(pair), pair });
}

while (!pq.empty()) {
const auto& top_element = pq.top();
const std::pair<int, int>& pair = top_element.second;
pq.pop();

bool pair_found = false;
for (size_t i = 0; i < ids.size() - 1; ++i) {
if (ids[i] == pair.first && ids[i + 1] == pair.second) {
pair_found = true;
break;
}
}
if (!pair_found) continue;

int idx = merges.at(pair);
ids = merge(ids, pair, idx);

stats = get_stats(ids);
for (const auto& stat : stats) {
const std::pair<int, int>& new_pair = stat.first;
if (merges.count(new_pair)) pq.push({ merges.at(new_pair), new_pair });
}
}

return ids;
}

// Encode the given text into subword tokens
// Encode the given text into subword tokens (advanced version)
std::vector<int> encode(const std::string& text) const
{
std::vector<int> result_ids;
@@ -247,7 +272,7 @@ namespace dlib
// Save the tokenizer model and vocabulary to file
friend void serialize(const bpe_tokenizer& tok, std::ostream& out)
{
serialize("bpe_tokenizer2_", out);
serialize("bpe_tokenizer_", out);
serialize(tok.special_tokens, out);
serialize(tok.special_token_map, out);
serialize(tok.merges, out);
@@ -259,7 +284,7 @@
friend void deserialize(bpe_tokenizer& tok, std::istream& in) {
std::string version;
dlib::deserialize(version, in);
if (version != "bpe_tokenizer2_")
if (version != "bpe_tokenizer_")
throw dlib::serialization_error("Unexpected version '" + version + "' found while deserializing dlib::bpe_tokenizer_.");
deserialize(tok.special_tokens, in);
deserialize(tok.special_token_map, in);
@@ -289,6 +314,9 @@ namespace dlib
std::map<int, std::vector<uint8_t>> vocab;
int vocab_size;

static const size_t MAX_TOKEN_LENGTH = 8;
static const int BASE_VOCAB_SIZE = 256;

// Get frequency statistics of adjacent token pairs
struct pair_hash {
template <class T1, class T2>
@@ -339,14 +367,16 @@ namespace dlib
// Iterate over all pairs in the statistics map
for (const auto& stat : stats) {
const std::pair<int, int>& pair = stat.first; // Extract the token pair
int count = stat.second; // Extract the frequency count
int frequency = stat.second; // Extract the frequency

// Check if the new token formed by merging the pair would exceed the maximum allowed length
size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
if (new_token_length > MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length

// Calculate the score for this pair (frequency * length_penalty)
double score = (size_t)count * (new_token_length > (BPE_TOKENIZER_MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);
double length_bonus = std::min(2.0, 1.0 + (static_cast<double>(new_token_length) - 2.0) * 0.1);
double frequency_weight = std::log1p(frequency);
double score = frequency_weight * length_bonus;

// Update the best pair if the current pair has a higher score
if (score > max_score)
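
The bpe_tokenizer.h changes above replace the file-scope BPE_TOKENIZER_* constants with private MAX_TOKEN_LENGTH and BASE_VOCAB_SIZE members, reserve merge IDs above the 28 special tokens, add an encode_raw() method that skips paragraph splitting and special-token wrapping, and change the merge-pair scoring to log1p(frequency) times a capped length bonus. Under the new rule, a pair seen 100 times whose merged token would be 6 bytes scores log1p(100) * min(2.0, 1.0 + (6 - 2) * 0.1) ≈ 4.62 * 1.4 ≈ 6.46, so frequent pairs are still preferred but raw frequency no longer dominates the length penalty outright. The sketch below is not taken from the PR; the include paths, the toy corpus, and the requested vocabulary size of 300 (which must be at least the 256 byte tokens plus the 28 special tokens, i.e. 284) are assumptions for illustration.

#include <iostream>
#include <string>
#include <vector>
#include <dlib/serialize.h>
#include <dlib/tokenizer/bpe_tokenizer.h>  // header path as it appears in this PR

int main()
{
    dlib::bpe_tokenizer tok;

    // In practice the training text would be a large corpus loaded from disk;
    // a short literal is used here only to show the call sequence.
    const std::string corpus =
        "the quick brown fox jumps over the lazy dog. "
        "the quick brown fox jumps over the lazy dog again.";

    // vocab_size must cover the base bytes plus the special tokens (>= 284).
    tok.train(corpus, /*vocab_size=*/300, /*verbose=*/true);

    // encode_raw() applies the learned merges directly; encode() additionally
    // performs paragraph splitting and special-token handling.
    const std::vector<int> raw_ids = tok.encode_raw("the quick brown fox");
    const std::vector<int> ids     = tok.encode("the quick brown fox");
    std::cout << "raw: " << raw_ids.size() << " tokens, "
              << "wrapped: " << ids.size() << " tokens" << std::endl;

    // The trained tokenizer can be saved with the usual dlib serialization.
    dlib::serialize("tokenizer.dat") << tok;
    return 0;
}
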
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -147,6 +147,7 @@ add_gui_example(dnn_dcgan_train_ex)
add_gui_example(dnn_yolo_train_ex)
add_gui_example(dnn_self_supervised_learning_ex)
add_example(slm_basic_train_ex)
add_example(slm_advanced_train_ex)
add_gui_example(3d_point_cloud_ex)
add_example(bayes_net_ex)
add_example(bayes_net_from_disk_ex)