
Commit d4bf94b

finalizing the example
1 parent 10d7c59 commit d4bf94b

1 file changed: +45 -32 lines changed

1 file changed

+45
-32
lines changed

examples/slm_advanced_train_ex.cpp

Lines changed: 45 additions & 32 deletions
@@ -56,8 +56,19 @@
 using namespace std;
 using namespace dlib;
 
-namespace ernie
+namespace dlib
 {
+    /*!
+        @class rotary_positional_embedding_
+        @brief Implements Rotary Positional Embeddings (RoPE) for transformers
+
+        This layer applies rotary positional embeddings to queries and keys in
+        self-attention layers, providing relative positional information without
+        absolute position embeddings.
+
+        The implementation follows the RoPE formulation from [2], where positions
+        are encoded through rotation matrices applied to pairs of dimensions.
+    !*/
     class rotary_positional_embedding_ {
     public:
         explicit rotary_positional_embedding_() = default;
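
Note: the new doc comment describes RoPE only in prose. As a rough point of reference, and not the layer's actual code, the rotation it refers to can be sketched for a single position pos and an even-length vector x, assuming the usual base of 10000:

#include <cmath>
#include <vector>

// Illustrative sketch of the RoPE rotation the comment above refers to; this is
// NOT the rotary_positional_embedding_ layer's implementation. Each pair
// (x[2k], x[2k+1]) is rotated by an angle that grows with the position and
// shrinks with the pair index k.
std::vector<float> apply_rope(const std::vector<float>& x, long pos, double base = 10000.0)
{
    std::vector<float> out(x.size());
    for (std::size_t i = 0; i + 1 < x.size(); i += 2)
    {
        const double theta = pos * std::pow(base, -static_cast<double>(i) / x.size());
        const double c = std::cos(theta);
        const double s = std::sin(theta);
        out[i]     = static_cast<float>(x[i] * c - x[i + 1] * s);
        out[i + 1] = static_cast<float>(x[i] * s + x[i + 1] * c);
    }
    return out;
}

In the layer itself the same rotation is applied per attention head to the query and key tensors, as the comment states; no absolute position embedding is added.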
@@ -386,7 +397,7 @@ namespace ernie
     struct model_info {
         static std::string describe() {
             std::stringstream ss;
-            ss << "ERNIE Transformer model configuration:\n"
+            ss << "Transformer model configuration:\n"
                << "- vocabulary size: " << VOCAB_SIZE << "\n"
                << "- layers: " << NUM_LAYERS << "\n"
                << "- attention heads: " << NUM_HEADS << "\n"
@@ -674,9 +685,9 @@ int main(int argc, char** argv)
     command_line_parser parser;
     parser.add_option("train", "Train a transformer model on enwiki");
     parser.add_option("generate", "Generate enwiki from a previously trained model");
-    parser.add_option("verify", "Verify generated output against original enwiki");
+    parser.add_option("verify", "Verify generated output against original data");
     parser.add_option("tokenize-only", "Only tokenize the input file and save tokens");
-    parser.add_option("enwiki", "Path to the enwiki file", 1);
+    parser.add_option("enwiki", "Path to the enwiki file (default: enwiki.txt)", 1);
     parser.add_option("max-tokens", "Maximum number of tokens to load in memory", 1);
     parser.add_option("max-bytes", "Maximum number of bytes to process from enwiki", 1);
     parser.add_option("percent", "Percentage of enwiki to process (0-100)", 1);
@@ -687,9 +698,9 @@ int main(int argc, char** argv)
     parser.add_option("alpha", "Set the weight decay for Adam (default: 0.004)", 1);
     parser.add_option("beta1", "Set Adam's first moment coefficient (default: 0.9)", 1);
     parser.add_option("beta2", "Set Adam's second moment coefficient (default: 0.999)", 1);
-    parser.add_option("model-file", "Path for model (default: ernie_model.dat)", 1);
+    parser.add_option("model-file", "Path for model (default: slm_enwiki_model.dat)", 1);
     parser.add_option("output-file", "Path for output (default: enwiki_generated.txt)", 1);
-    parser.add_option("tokenizer", "Path to pre-trained tokenizer (default: ernie_tokenizer.vocab)", 1);
+    parser.add_option("tokenizer", "Path to pre-trained tokenizer (default: enwiki_tokenizer.vocab)", 1);
     parser.add_option("tokens-file", "Path to pre-tokenized tokens file (optional)", 1);
     parser.add_option("force-tokenize", "Force tokenization even if tokens file exists");
     parser.parse(argc, argv);
@@ -710,14 +721,14 @@ int main(int argc, char** argv)
     const double alpha = get_option(parser, "alpha", 0.004);
     const double beta1 = get_option(parser, "beta1", 0.9);
     const double beta2 = get_option(parser, "beta2", 0.999);
-    const std::string model_file = get_option(parser, "model-file", "ernie_model.dat");
+    const std::string model_file = get_option(parser, "model-file", "slm_enwiki_model.dat");
     const std::string output_file = get_option(parser, "output-file", "enwiki_generated.txt");
-    const std::string enwiki_path = get_option(parser, "enwiki", "enwiki");
+    const std::string enwiki_path = get_option(parser, "enwiki", "enwiki.txt");
     const long max_seq_len = 180;
     const long num_layers = 2;
     const long num_heads = 6;
     const long embedding_dim = 228;
-    const std::string tokenizer_path = get_option(parser, "tokenizer", "ernie_tokenizer.vocab");
+    const std::string tokenizer_path = get_option(parser, "tokenizer", "enwiki_tokenizer.vocab");
     // Default number of prompt tokens = input sequence length
     const bool force_tokenize = parser.option("force-tokenize");
     const long num_tokens = 1000;
@@ -760,7 +771,7 @@ int main(int argc, char** argv)
         parser.option("tokens-file").argument() :
         generate_tokens_filename(enwiki_path, max_bytes);
 
-    using ernie_transformer = ernie::transformer_config<
+    using enwiki_transformer = transformer_config<
         num_tokens,    // vocab_size
         num_layers,    // number of layers
         num_heads,     // number of attention heads
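
Side note on the constants combined here: because the RoPE comment earlier in the diff says positions are encoded by rotating pairs of dimensions, the per-head dimension has to be even. With embedding_dim = 228 and num_heads = 6 that gives 228 / 6 = 38 dimensions per head, which satisfies the constraint. A hypothetical compile-time check (not part of the example) could read:

// Hypothetical sanity check, not code from this example.
constexpr long kEmbeddingDim = 228;
constexpr long kNumHeads     = 6;
constexpr long kHeadDim      = kEmbeddingDim / kNumHeads;   // 228 / 6 = 38
static_assert(kEmbeddingDim % kNumHeads == 0, "embedding dim must split evenly across heads");
static_assert(kHeadDim % 2 == 0, "RoPE rotates dimension pairs, so the head dim must be even");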
@@ -945,9 +956,9 @@ int main(int argc, char** argv)
     cout << "Created " << samples.size() << " training samples (100%)...\n";
 
     // 5) Build and train the network
-    using net_type = ernie_transformer::network_type<true>;
+    using net_type = enwiki_transformer::network_type<true>;
     net_type net;
-    cout << "Model architecture:\n" << ernie_transformer::model_info::describe() << endl;
+    cout << "Model architecture:\n" << enwiki_transformer::model_info::describe() << endl;
     if (file_exists(model_file)) deserialize(model_file) >> net;
 
     // Create trainer
@@ -958,7 +969,7 @@ int main(int argc, char** argv)
     // For perfect memorization, we allow more epochs without improvement
     trainer.set_iterations_without_progress_threshold(patience);
     trainer.set_max_num_epochs(max_epochs); // More epochs for perfect memorization
-    trainer.set_synchronization_file("ernie_trainer.sync", std::chrono::minutes(10));
+    trainer.set_synchronization_file("enwiki_trainer.sync", std::chrono::minutes(10));
     trainer.be_quiet();
 
     // Custom training loop - trainer.train(samples, labels)
@@ -1027,27 +1038,29 @@ int main(int argc, char** argv)
     net.clean();
     serialize(model_file) << net;
     cout << "Model saved to " << model_file << "\n";
-    std::remove("ernie_trainer.sync");
-    std::remove("ernie_trainer.sync_");
+    std::remove("enwiki_trainer.sync");
+    std::remove("enwiki_trainer.sync_");
 
     // Evaluate on training set
-    if (!g_terminate_flag.load()) {
-        cout << "Evaluating model accuracy...\n";
-        using net_infer = ernie_transformer::network_type<false>;
-        net_infer g_infer = net;
-        auto predicted = g_infer(samples);
-        size_t correct = 0;
-        for (size_t i = 0; i < labels.size(); ++i)
-            if (predicted[i] == labels[i]) correct++;
-        double accuracy = (double)correct / labels.size();
-        cout << "Training accuracy: " << (accuracy * 100.0) << "%\n";
-
-        // We need perfect accuracy to reconstruct enwiki
-        if (accuracy < 0.9999) {
-            cout << "WARNING: Model accuracy is less than 99.99%. The model may not "
-                << "perfectly reconstruct the input text.\n";
+    {
+        if (!g_terminate_flag.load()) {
+            cout << "Evaluating model accuracy...\n";
+            using net_infer = enwiki_transformer::network_type<false>;
+            net_infer g_infer = net;
+            auto predicted = g_infer(samples);
+            size_t correct = 0;
+            for (size_t i = 0; i < labels.size(); ++i)
+                if (predicted[i] == labels[i]) correct++;
+            double accuracy = (double)correct / labels.size();
+            cout << "Training accuracy: " << (accuracy * 100.0) << "%\n";
+
+            // We need perfect accuracy to reconstruct enwiki
+            if (accuracy < 0.9999) {
+                cout << "WARNING: Model accuracy is less than 99.99%. The model may not "
+                    << "perfectly reconstruct the input text.\n";
+            }
         }
-    }
+    }
     }
 
     // ----------------------------------------------------------------------------------------
@@ -1058,7 +1071,7 @@ int main(int argc, char** argv)
     cout << "=== GENERATION MODE ===\n";
 
     // 1) Load the model
-    using net_infer = ernie_transformer::network_type<false>;
+    using net_infer = enwiki_transformer::network_type<false>;
     net_infer net;
     if (file_exists(model_file)) {
         deserialize(model_file) >> net;
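
This hunk only shows the inference network being loaded. For orientation, a greedy next-token loop in the style of dlib's small-language-model examples might look like the sketch below; the window type, element type, and shifting logic are illustrative assumptions, not code from this file:

#include <vector>
#include <dlib/matrix.h>

// Minimal greedy-decoding sketch (illustrative assumptions, not this example's code).
// `window` holds the most recent seq_len token ids; the net predicts the next token,
// which is recorded and the window is shifted left by one position.
template <typename net_t>
std::vector<int> greedy_generate(net_t& net, dlib::matrix<int, 0, 1> window, long num_new_tokens)
{
    const long seq_len = window.size();
    std::vector<int> generated;
    for (long step = 0; step < num_new_tokens; ++step)
    {
        const unsigned long next_token = net(window);   // greedy: most likely next token
        generated.push_back(static_cast<int>(next_token));
        for (long i = 0; i < seq_len - 1; ++i)          // slide the context window
            window(i) = window(i + 1);
        window(seq_len - 1) = static_cast<int>(next_token);
    }
    return generated;
}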
