diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index 7ec8b1a956..2b0136ef91 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -2329,24 +2329,46 @@ namespace dlib
     template <
         unsigned long num_outputs_,
-        linear_bias_mode bias_mode_
+        linear_bias_mode bias_mode_ = LINEAR_HAS_BIAS
         >
     class linear_
     {
         static_assert(num_outputs_ > 0, "The number of outputs from a linear_ layer must be > 0");
 
     public:
-        linear_() :
+        explicit linear_() :
             num_outputs(num_outputs_),
-            num_inputs(0),
+            num_inputs(0),
             learning_rate_multiplier(1),
             bias_mode(bias_mode_) {
         }
 
+        linear_(const linear_& other) :
+            num_outputs(other.num_outputs),
+            num_inputs(other.num_inputs),
+            learning_rate_multiplier(other.learning_rate_multiplier),
+            bias_mode(other.bias_mode),
+            params(other.params),
+            weights(other.weights),
+            biases(other.biases) {
+        }
+
+        linear_& operator=(const linear_& other) {
+            if (this != &other) {
+                num_outputs = other.num_outputs;
+                num_inputs = other.num_inputs;
+                learning_rate_multiplier = other.learning_rate_multiplier;
+                bias_mode = other.bias_mode;
+                params = other.params;
+                weights = other.weights;
+                biases = other.biases;
+            }
+            return *this;
+        }
+
         double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
         void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
-
-        unsigned long get_num_inputs() const { return num_inputs; }
+        unsigned long get_num_outputs() const { return num_outputs; }
         void set_num_outputs(long num)
         {
@@ -2358,6 +2380,7 @@ namespace dlib
                 num_outputs = num;
             }
         }
+        unsigned long get_num_inputs() const { return num_inputs; }
         linear_bias_mode get_bias_mode() const { return bias_mode; }
 
         template <typename SUBNET>
@@ -2503,8 +2526,8 @@ namespace dlib
         }
 
     private:
-        unsigned long num_inputs;
         unsigned long num_outputs;
+        unsigned long num_inputs;
         double learning_rate_multiplier;
         linear_bias_mode bias_mode;
         resizable_tensor params;
@@ -2515,7 +2538,7 @@ namespace dlib
         unsigned long num_outputs,
         typename SUBNET
         >
-    using linear = add_layer<linear_<num_outputs, LINEAR_HAS_BIAS>, SUBNET>;
+    using linear = add_layer<linear_<num_outputs>, SUBNET>;
 
     template <
         unsigned long num_outputs,
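With `bias_mode_` defaulted above, `linear_<N>` now resolves to `linear_<N, LINEAR_HAS_BIAS>`. A minimal usage sketch (hypothetical toy network, not part of this patch):

// Hypothetical example: the `linear` alias now picks up LINEAR_HAS_BIAS by
// default, while `linear_no_bias` remains the explicit bias-free variant.
using toy_net = loss_multiclass_log<
                fc<10,
                relu<linear<128,
                input<matrix<float>>>>>>;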
diff --git a/dlib/tokenizer/bpe_tokenizer.h b/dlib/tokenizer/bpe_tokenizer.h
index f9457b554f..642f7c760b 100644
--- a/dlib/tokenizer/bpe_tokenizer.h
+++ b/dlib/tokenizer/bpe_tokenizer.h
@@ -20,49 +20,32 @@ namespace dlib
 {
-    constexpr size_t BPE_TOKENIZER_MAX_TOKEN_LENGTH = 8;
-    constexpr int BPE_TOKENIZER_BASE_VOCAB_SIZE = 256;
 
     class bpe_tokenizer
     {
     public:
-        bpe_tokenizer() : vocab_size(BPE_TOKENIZER_BASE_VOCAB_SIZE)
+        bpe_tokenizer() : vocab_size(BASE_VOCAB_SIZE)
        {
             // Initialize the base vocabulary with single bytes
-            for (int i = 0; i < BPE_TOKENIZER_BASE_VOCAB_SIZE; ++i)
+            for (int i = 0; i < BASE_VOCAB_SIZE; ++i)
                 vocab[i] = std::vector<uint8_t>{ static_cast<uint8_t>(i) };
 
             // Initialize special tokens with sequential IDs
-            special_tokens =
-            {
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 1},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 2},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 3},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 4},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 5},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 7},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 9},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 10},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 11},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 12},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 13},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 14},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 15},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 16},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 17},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 18},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 19},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 20},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 21},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 22},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 23},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 24},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 25},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 26},
-                {"", BPE_TOKENIZER_BASE_VOCAB_SIZE + 27}
+            special_tokens = {
+                {"", BASE_VOCAB_SIZE},      {"", BASE_VOCAB_SIZE + 1},
+                {"", BASE_VOCAB_SIZE + 2},  {"", BASE_VOCAB_SIZE + 3},
+                {"", BASE_VOCAB_SIZE + 4},  {"", BASE_VOCAB_SIZE + 5},
+                {"", BASE_VOCAB_SIZE + 7},
+                {"", BASE_VOCAB_SIZE + 9},
+                {"", BASE_VOCAB_SIZE + 10}, {"", BASE_VOCAB_SIZE + 11},
+                {"", BASE_VOCAB_SIZE + 12}, {"", BASE_VOCAB_SIZE + 13},
+                {"", BASE_VOCAB_SIZE + 14}, {"", BASE_VOCAB_SIZE + 15},
+                {"", BASE_VOCAB_SIZE + 16}, {"", BASE_VOCAB_SIZE + 17},
+                {"", BASE_VOCAB_SIZE + 18}, {"", BASE_VOCAB_SIZE + 19},
+                {"", BASE_VOCAB_SIZE + 20}, {"", BASE_VOCAB_SIZE + 21},
+                {"", BASE_VOCAB_SIZE + 22}, {"", BASE_VOCAB_SIZE + 23},
+                {"", BASE_VOCAB_SIZE + 24}, {"", BASE_VOCAB_SIZE + 25},
+                {"", BASE_VOCAB_SIZE + 26}, {"", BASE_VOCAB_SIZE + 27}
             };
 
             // Initialize the vector of special token IDs
@@ -73,57 +56,99 @@ namespace dlib
         // Train the tokenizer on the given text
         void train(const std::string& text, int vocab_size, bool verbose = false)
         {
-            DLIB_CASSERT(vocab_size >= BPE_TOKENIZER_BASE_VOCAB_SIZE);
-            this->vocab_size = vocab_size;
-            int num_merges = vocab_size - BPE_TOKENIZER_BASE_VOCAB_SIZE;
+            int current_base = static_cast<int>(BASE_VOCAB_SIZE + special_tokens.size());
+            DLIB_CASSERT(vocab_size >= current_base);
+            int num_merges = vocab_size - current_base;
+            if (num_merges <= 0) return;
 
             // Convert text to byte IDs
             std::vector<int> ids;
+            ids.reserve(text.size());
             for (char c : text) ids.push_back(static_cast<uint8_t>(c));
 
             // Perform BPE merges
-            for (int i = 0; i < num_merges; ++i) {
+            int n_merges = 0;
+            for (; n_merges < num_merges; ++n_merges) {
                 auto stats = get_stats(ids);
                 if (stats.empty()) break;
 
-                // Find the most frequent pair that does not exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
+                // Find the most frequent pair that does not exceed MAX_TOKEN_LENGTH
                 auto pair = get_most_frequent_pair(stats);
+                if (pair.first == -1) break;
 
-                // Check if the resulting token would exceed BPE_TOKENIZER_MAX_TOKEN_LENGTH
+                // Check if the resulting token would exceed MAX_TOKEN_LENGTH
                 size_t new_token_length = vocab[pair.first].size() + vocab[pair.second].size();
-                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) {
+                if (new_token_length > MAX_TOKEN_LENGTH) {
                     if (verbose)
-                    {
-                        std::cout << "\r"
-                            << std::setw(100) << std::flush
-                            << "\rskipping merge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
-                            << std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> new token length "
-                            << std::to_string(new_token_length) << " exceeds limit of " << std::to_string(BPE_TOKENIZER_MAX_TOKEN_LENGTH)
-                            << std::flush;
-                    }
+                        std::cout << "\r" << std::setw(100) << std::flush << "\r[skip] merge " << (n_merges + 1)
+                            << ": token too long: " << new_token_length << "/" << MAX_TOKEN_LENGTH << std::flush;
                     continue; // Skip this merge
                 }
-                int idx = (BPE_TOKENIZER_BASE_VOCAB_SIZE + (int)special_tokens.size()) + i;
-                ids = merge(ids, pair, idx);
-                merges[pair] = idx;
-                vocab[idx].insert(vocab[idx].end(), vocab[pair.first].begin(), vocab[pair.first].end());
-                vocab[idx].insert(vocab[idx].end(), vocab[pair.second].begin(), vocab[pair.second].end());
+                int new_id = current_base + n_merges;
+                merges[pair] = new_id;
+
+                std::vector<uint8_t>& new_token = vocab[new_id];
+                new_token.reserve(new_token_length);
+                new_token.insert(new_token.end(), vocab[pair.first].begin(), vocab[pair.first].end());
+                new_token.insert(new_token.end(), vocab[pair.second].begin(), vocab[pair.second].end());
+
+                ids = merge(ids, pair, new_id);
 
                 if (verbose)
-                {
-                    std::cout << "\r"
-                        << std::setw(100) << std::flush
-                        << "\rmerge " << std::to_string(i + 1) << "/" << std::to_string(num_merges) << ": ("
-                        << std::to_string(pair.first) << "," << std::to_string(pair.second) << ") -> " << std::to_string(idx)
-                        << " (" << bytes_to_string(vocab[idx]) << ") had "
-                        << std::to_string(stats[pair]) << " occurrences"
-                        << std::endl;
+                    std::cout << "\r" << std::setw(100) << std::flush << "\r[merge] " << (n_merges + 1) << "/" << num_merges
+                        << ": (" << pair.first << "," << pair.second << ") -> " << new_id
+                        << " (" << bytes_to_string(vocab[new_id]) << ")" << std::endl;
             }
+            this->vocab_size = current_base + n_merges;
         }
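+
+        // Illustrative example (values for illustration only): training on the
+        // string "aaab" starts from the byte ids [97, 97, 97, 98]. The pair
+        // (97, 97) occurs twice and scores highest, so it becomes the first
+        // merged id (current_base + 0) and the sequence collapses to
+        // [current_base, 97, 98]. Each pass rescans the shortened sequence
+        // until num_merges merges have been made or no mergeable pair remains.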
+
+        // Encode the given text into subword tokens without paragraph splitting or special token wrapping
+        std::vector<int> encode_raw(const std::string& text) const
+        {
+            // Direct encoding without paragraph splitting or special tokens
+            std::vector<int> ids;
+            ids.reserve(text.size());
+
+            // Convert text to character IDs
+            for (char c : text) ids.push_back(static_cast<uint8_t>(c));
+
+            // Apply BPE merges
+            auto stats = get_stats(ids);
+            std::priority_queue<std::pair<int, std::pair<int, int>>> pq;
+            for (const auto& stat : stats) {
+                const std::pair<int, int>& pair = stat.first;
+                if (merges.count(pair)) pq.push({ merges.at(pair), pair });
+            }
+
+            while (!pq.empty()) {
+                // Copy before popping: a reference into the queue would dangle after pop().
+                const std::pair<int, int> pair = pq.top().second;
+                pq.pop();
+
+                bool pair_found = false;
+                for (size_t i = 0; i + 1 < ids.size(); ++i) {
+                    if (ids[i] == pair.first && ids[i + 1] == pair.second) {
+                        pair_found = true;
+                        break;
+                    }
+                }
+                if (!pair_found) continue;
+
+                int idx = merges.at(pair);
+                ids = merge(ids, pair, idx);
+
+                stats = get_stats(ids);
+                for (const auto& stat : stats) {
+                    const std::pair<int, int>& new_pair = stat.first;
+                    if (merges.count(new_pair)) pq.push({ merges.at(new_pair), new_pair });
+                }
+            }
+
+            return ids;
+        }
 
-        // Encode the given text into subword tokens
+        // Encode the given text into subword tokens (advanced version)
         std::vector<int> encode(const std::string& text) const
         {
             std::vector<int> result_ids;
@@ -247,7 +272,7 @@ namespace dlib
         // Save the tokenizer model and vocabulary to file
         friend void serialize(const bpe_tokenizer& tok, std::ostream& out)
         {
-            serialize("bpe_tokenizer2_", out);
+            serialize("bpe_tokenizer_", out);
             serialize(tok.special_tokens, out);
             serialize(tok.special_token_map, out);
             serialize(tok.merges, out);
@@ -259,7 +284,7 @@ namespace dlib
         friend void deserialize(bpe_tokenizer& tok, std::istream& in)
         {
             std::string version;
             dlib::deserialize(version, in);
-            if (version != "bpe_tokenizer2_")
+            if (version != "bpe_tokenizer_")
                 throw dlib::serialization_error("Unexpected version '" + version + "' found while deserializing dlib::bpe_tokenizer_.");
             deserialize(tok.special_tokens, in);
             deserialize(tok.special_token_map, in);
@@ -289,6 +314,9 @@ namespace dlib
         std::map<int, std::vector<uint8_t>> vocab;
         int vocab_size;
 
+        static const size_t MAX_TOKEN_LENGTH = 8;
+        static const int BASE_VOCAB_SIZE = 256;
+
         // Get frequency statistics of adjacent token pairs
         struct pair_hash {
             template <class T1, class T2>
@@ -339,14 +367,16 @@ namespace dlib
             // Iterate over all pairs in the statistics map
             for (const auto& stat : stats) {
                 const std::pair<int, int>& pair = stat.first; // Extract the token pair
-                int count = stat.second; // Extract the frequency count
+                int frequency = stat.second; // Extract the frequency
 
                 // Check if the new token formed by merging the pair would exceed the maximum allowed length
                 size_t new_token_length = vocab.at(pair.first).size() + vocab.at(pair.second).size();
-                if (new_token_length > BPE_TOKENIZER_MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
+                if (new_token_length > MAX_TOKEN_LENGTH) continue; // Skip this pair if it exceeds the maximum token length
 
                 // Calculate the score for this pair (frequency * length_penalty)
-                double score = (size_t)count * (new_token_length > (BPE_TOKENIZER_MAX_TOKEN_LENGTH / 2) ? 1.75 : 1.0);
+                double length_bonus = std::min(2.0, 1.0 + (static_cast<double>(new_token_length) - 2.0) * 0.1);
+                double frequency_weight = std::log1p(frequency);
+                double score = frequency_weight * length_bonus;
 
                 // Update the best pair if the current pair has a higher score
                 if (score > max_score)
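To make the revised scoring concrete: a pair seen 100 times whose merged token would be 6 bytes long scores log1p(100) * min(2.0, 1.0 + (6 - 2) * 0.1) = 4.62 * 1.4 ≈ 6.46, while a pair seen 400 times that would form only a 2-byte token scores log1p(400) * 1.0 ≈ 5.99. Under the old linear rule the raw counts would have decided the outcome (175 vs 400); the logarithmic frequency weight lets longer, more informative merges outrank sheer frequency.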
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index c23067879a..1232d58b09 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -147,6 +147,7 @@ add_gui_example(dnn_dcgan_train_ex)
 add_gui_example(dnn_yolo_train_ex)
 add_gui_example(dnn_self_supervised_learning_ex)
 add_example(slm_basic_train_ex)
+add_example(slm_advanced_train_ex)
 add_gui_example(3d_point_cloud_ex)
 add_example(bayes_net_ex)
 add_example(bayes_net_from_disk_ex)
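Before the new example file, a minimal driver for the tokenizer API above (a hypothetical sketch, not part of this patch; the corpus path and target vocabulary size are placeholders):

#include <dlib/tokenizer/bpe_tokenizer.h>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main()
{
    // Hypothetical driver, for illustration only.
    std::ifstream fin("corpus.txt");  // placeholder training corpus
    const std::string text((std::istreambuf_iterator<char>(fin)),
                           std::istreambuf_iterator<char>());

    dlib::bpe_tokenizer tok;
    // Target vocab = 256 base bytes + special tokens + learned merges.
    tok.train(text, 3500, true);

    // encode_raw() skips paragraph splitting and special-token wrapping.
    const std::vector<int> ids = tok.encode_raw("Hello, world!");
    std::cout << "token count: " << ids.size() << std::endl;

    dlib::serialize("bpe_tokenizer.dat") << tok;  // round-trips via deserialize()
    return 0;
}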
diff --git a/examples/slm_advanced_train_ex.cpp b/examples/slm_advanced_train_ex.cpp
new file mode 100644
index 0000000000..b655b13ba1
--- /dev/null
+++ b/examples/slm_advanced_train_ex.cpp
@@ -0,0 +1,1305 @@
+/*!
+    @file slm_advanced_train_ex.cpp
+    @brief Transformer-based text training/generation
+
+    This program implements a complete training and generation pipeline for a
+    Transformer-based text compression system. The model features:
+
+    1. Rotary Positional Embeddings (RoPE) for enhanced positional encoding
+    2. Multi-head self-attention with efficient memory handling
+    3. Mixture-of-Experts architecture for specialized processing
+    4. BPE tokenization with custom vocabulary
+    5. Full training/generation/verification workflow
+
+    Key capabilities demonstrated:
+    - Perfect memorization and reproduction of training text
+    - Efficient autoregressive generation
+    - Byte-level verification of reconstructed text
+
+    References:
+    [1] Vaswani et al., "Attention Is All You Need" (Transformer architecture)
+        arXiv:1706.03762
+    [2] Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+        arXiv:2104.09864
+    [3] Shazeer et al., "Outrageously Large Neural Networks: The Sparsely-Gated
+        Mixture-of-Experts Layer" (MoE architecture) arXiv:1701.06538
+
+    Usage modes:
+    --train          Train model on enwiki dataset
+    --generate       Generate text from trained model
+    --verify         Compare generated output with original
+    --tokenize-only  Only perform tokenization step
+
+    Configuration:
+    - Adjust template parameters in transformer_config for model architecture
+    - Modify training parameters in main() for optimization
+    - Set sequence length and memory limits according to available hardware
+!*/
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace std;
+using namespace dlib;
+
+namespace dlib
+{
+    /*!
+        @class rotary_positional_embedding_
+        @brief Implements Rotary Positional Embeddings (RoPE) for transformers
+
+        This layer applies rotary positional embeddings to queries and keys in
+        self-attention layers, providing relative positional information without
+        absolute position embeddings.
+
+        The implementation follows the RoPE formulation from [2], where positions
+        are encoded through rotation matrices applied to pairs of dimensions.
+    !*/
+    class rotary_positional_embedding_ {
+    public:
+        explicit rotary_positional_embedding_() = default;
+
+        template <typename SUBNET>
+        void setup(const SUBNET& sub) {
+            // Precompute the rotation angles and their trigonometric values
+            seq_len = sub.get_output().nr();
+            d_head = sub.get_output().nc();
+            compute_rotation_angles();
+            precompute_trigonometric_values();
+        }
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output) {
+            const tensor& input = sub.get_output();
+            output.copy_size(input);
+            tt::copy_tensor(false, output, 0, input, 0, input.k());
+
+            // Apply rotary embedding to the output
+            apply_rotary_embedding(output);
+        }
+
+        template <typename SUBNET>
+        void backward(
+            const tensor& gradient_input,
+            SUBNET& sub,
+            tensor& params_grad
+        ) {
+            tensor& prev = sub.get_gradient_input();
+            resizable_tensor grad_output;
+            grad_output.copy_size(gradient_input);
+            tt::copy_tensor(false, grad_output, 0, gradient_input, 0, gradient_input.k());
+
+            // Apply the inverse rotation to the gradient (transpose of the rotation matrix)
+            apply_rotary_embedding(grad_output, true);
+            tt::copy_tensor(true, prev, 0, grad_output, 0, grad_output.k());
+        }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        friend void serialize(const rotary_positional_embedding_& item, std::ostream& out) {
+            std::string version = "rotary_positional_embedding_";
+            dlib::serialize(version, out);
+            dlib::serialize(item.seq_len, out);
+            dlib::serialize(item.d_head, out);
+            dlib::serialize(item.angles, out);
+            dlib::serialize(item.cos_values, out);
+            dlib::serialize(item.sin_values, out);
+        }
+
+        friend void deserialize(rotary_positional_embedding_& item, std::istream& in) {
+            std::string version;
+            dlib::deserialize(version, in);
+            if (version != "rotary_positional_embedding_")
+                throw serialization_error("Unexpected version found while deserializing rotary_positional_embedding_.");
+            dlib::deserialize(item.seq_len, in);
+            dlib::deserialize(item.d_head, in);
+            dlib::deserialize(item.angles, in);
+            dlib::deserialize(item.cos_values, in);
+            dlib::deserialize(item.sin_values, in);
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const rotary_positional_embedding_& item) {
+            out << "rotary_positional_embedding";
+            out << " (d_head=" << item.d_head << ", seq_len=" << item.seq_len << ")";
+            return out;
+        }
+
+        friend void to_xml(const rotary_positional_embedding_& item, std::ostream& out)
+        {
+            out << "<rotary_positional_embedding/>\n";
+        }
+
+    protected:
+        void compute_rotation_angles() {
+            // Following the original RoPE paper formulation
+            const float base = 10000.0f;
+            const long half_dim = d_head / 2;
+            angles.set_size(seq_len, half_dim);
+
+            for (long pos = 0; pos < seq_len; ++pos) {
+                for (long i = 0; i < half_dim; ++i) {
+                    float inv_freq = std::pow(base, -2.0f * (i + 0.5f) / d_head);
+                    angles(pos, i) = pos * inv_freq;
+                }
+            }
+        }
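+
+        // For reference, the rotation applied per position `pos` and dimension
+        // pair (2i, 2i+1), with theta = angles(pos, i) -- see [2]:
+        //   x'[2i]   = x[2i] * cos(theta) - x[2i+1] * sin(theta)
+        //   x'[2i+1] = x[2i] * sin(theta) + x[2i+1] * cos(theta)
+        // backward() applies the transposed (inverse) rotation, which is why
+        // this layer needs no learnable parameters.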
+
+        void precompute_trigonometric_values() {
+            // Precompute cos and sin for all angles
+            cos_values.set_size(angles.nr(), angles.nc());
+            sin_values.set_size(angles.nr(), angles.nc());
+
+            // Index row/column explicitly: dlib matrices only allow linear
+            // indexing on vectors, not on 2-D matrices.
+            for (long r = 0; r < angles.nr(); ++r) {
+                for (long c = 0; c < angles.nc(); ++c) {
+                    cos_values(r, c) = std::cos(angles(r, c));
+                    sin_values(r, c) = std::sin(angles(r, c));
+                }
+            }
+        }
+
+        template <typename tensor_type>
+        void apply_rotary_embedding(
+            tensor_type& x,
+            bool is_backward = false
+        ) const {
+            const long batch_size = x.num_samples();
+            const long num_heads = x.k();
+            const long seq_length = x.nr();
+            const long dim = x.nc();
+            const bool is_odd = (dim % 2 != 0);
+            const long rot_dim = is_odd ? dim - 1 : dim;
+
+            DLIB_CASSERT(dim == d_head, "Input dimension must match d_head param");
+            DLIB_CASSERT(seq_length == seq_len, "Sequence length must match seq_len param");
+
+            auto* ptr = x.host();
+            const long stride = seq_length * dim;
+
+            for (long n = 0; n < batch_size; ++n) {
+                for (long h = 0; h < num_heads; ++h) {
+                    auto* x_ptr = ptr + (n * num_heads + h) * stride;
+
+                    for (long pos = 0; pos < seq_length; ++pos) {
+                        const float* cos = &cos_values(pos, 0);
+                        const float* sin = &sin_values(pos, 0);
+
+                        for (long i = 0; i < rot_dim; i += 2) {
+                            const float x0 = x_ptr[pos * dim + i];
+                            const float x1 = x_ptr[pos * dim + i + 1];
+
+                            if (!is_backward) {
+                                x_ptr[pos * dim + i] = x0 * cos[i / 2] - x1 * sin[i / 2];
+                                x_ptr[pos * dim + i + 1] = x0 * sin[i / 2] + x1 * cos[i / 2];
+                            }
+                            else {
+                                x_ptr[pos * dim + i] = x0 * cos[i / 2] + x1 * sin[i / 2];
+                                x_ptr[pos * dim + i + 1] = -x0 * sin[i / 2] + x1 * cos[i / 2];
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+    private:
+        long seq_len, d_head;       // Sequence length and dimension of each head
+        matrix<float> angles;       // Precomputed rotation angles (seq_len x d_head/2)
+        matrix<float> cos_values;   // Precomputed cosine values
+        matrix<float> sin_values;   // Precomputed sine values
+        resizable_tensor params;    // Empty tensor (no learnable parameters)
+    };
+
+    // Helper to easily add RoPE to a network
+    template <typename SUBNET>
+    using rope = add_layer<rotary_positional_embedding_, SUBNET>;
+
+    template <long d_k_>
+    class scale_weights_ : public multiply_ {
+    public:
+        explicit scale_weights_() : multiply_(1.0f / std::sqrt(static_cast<float>(d_k_))) {}
+    };
+
+    template <long d_k, typename SUBNET>
+    using scale_weights = add_layer<scale_weights_<d_k>, SUBNET>;
+
+    // Attention mechanism component extractors
+    template
+    using query = reshape_to>;
+
+    template
+    using key = reshape_to>;
+
+    template
+    using value = reshape_to>;
+
+    /*!
+        This layer implements multi-head self-attention.
+
+        Template parameters:
+        - ACT: Activation function type
+        - DO: Dropout layer type for regularization
+        - d_model: Model dimension (must be divisible by num_heads)
+        - num_heads: Number of attention heads
+    !*/
+    template