56
56
using namespace std ;
57
57
using namespace dlib ;
58
58
59
- namespace ernie
59
+ namespace dlib
60
60
{
61
+ /*!
62
+ @class rotary_positional_embedding_
63
+ @brief Implements Rotary Positional Embeddings (RoPE) for transformers
64
+
65
+ This layer applies rotary positional embeddings to queries and keys in
66
+ self-attention layers, providing relative positional information without
67
+ absolute position embeddings.
68
+
69
+ The implementation follows the RoPE formulation from [2], where positions
70
+ are encoded through rotation matrices applied to pairs of dimensions.
71
+ !*/
61
72
class rotary_positional_embedding_ {
62
73
public:
63
74
explicit rotary_positional_embedding_ () = default;
@@ -386,7 +397,7 @@ namespace ernie
386
397
struct model_info {
387
398
static std::string describe () {
388
399
std::stringstream ss;
389
- ss << " ERNIE Transformer model configuration:\n "
400
+ ss << " Transformer model configuration:\n "
390
401
<< " - vocabulary size: " << VOCAB_SIZE << " \n "
391
402
<< " - layers: " << NUM_LAYERS << " \n "
392
403
<< " - attention heads: " << NUM_HEADS << " \n "
@@ -674,9 +685,9 @@ int main(int argc, char** argv)
674
685
command_line_parser parser;
675
686
parser.add_option (" train" , " Train a transformer model on enwiki" );
676
687
parser.add_option (" generate" , " Generate enwiki from a previously trained model" );
677
- parser.add_option (" verify" , " Verify generated output against original enwiki " );
688
+ parser.add_option (" verify" , " Verify generated output against original data " );
678
689
parser.add_option (" tokenize-only" , " Only tokenize the input file and save tokens" );
679
- parser.add_option (" enwiki" , " Path to the enwiki file" , 1 );
690
+ parser.add_option (" enwiki" , " Path to the enwiki file (default: enwiki.txt) " , 1 );
680
691
parser.add_option (" max-tokens" , " Maximum number of tokens to load in memory" , 1 );
681
692
parser.add_option (" max-bytes" , " Maximum number of bytes to process from enwiki" , 1 );
682
693
parser.add_option (" percent" , " Percentage of enwiki to process (0-100)" , 1 );
@@ -687,9 +698,9 @@ int main(int argc, char** argv)
687
698
parser.add_option (" alpha" , " Set the weight decay for Adam (default: 0.004)" , 1 );
688
699
parser.add_option (" beta1" , " Set Adam's first moment coefficient (default: 0.9)" , 1 );
689
700
parser.add_option (" beta2" , " Set Adam's second moment coefficient (default: 0.999)" , 1 );
690
- parser.add_option (" model-file" , " Path for model (default: ernie_model .dat)" , 1 );
701
+ parser.add_option (" model-file" , " Path for model (default: slm_enwiki_model .dat)" , 1 );
691
702
parser.add_option (" output-file" , " Path for output (default: enwiki_generated.txt)" , 1 );
692
- parser.add_option (" tokenizer" , " Path to pre-trained tokenizer (default: ernie_tokenizer .vocab)" , 1 );
703
+ parser.add_option (" tokenizer" , " Path to pre-trained tokenizer (default: enwiki_tokenizer .vocab)" , 1 );
693
704
parser.add_option (" tokens-file" , " Path to pre-tokenized tokens file (optional)" , 1 );
694
705
parser.add_option (" force-tokenize" , " Force tokenization even if tokens file exists" );
695
706
parser.parse (argc, argv);
@@ -710,14 +721,14 @@ int main(int argc, char** argv)
710
721
const double alpha = get_option (parser, " alpha" , 0.004 );
711
722
const double beta1 = get_option (parser, " beta1" , 0.9 );
712
723
const double beta2 = get_option (parser, " beta2" , 0.999 );
713
- const std::string model_file = get_option (parser, " model-file" , " ernie_model .dat" );
724
+ const std::string model_file = get_option (parser, " model-file" , " slm_enwiki_model .dat" );
714
725
const std::string output_file = get_option (parser, " output-file" , " enwiki_generated.txt" );
715
- const std::string enwiki_path = get_option (parser, " enwiki" , " enwiki" );
726
+ const std::string enwiki_path = get_option (parser, " enwiki" , " enwiki.txt " );
716
727
const long max_seq_len = 180 ;
717
728
const long num_layers = 2 ;
718
729
const long num_heads = 6 ;
719
730
const long embedding_dim = 228 ;
720
- const std::string tokenizer_path = get_option (parser, " tokenizer" , " ernie_tokenizer .vocab" );
731
+ const std::string tokenizer_path = get_option (parser, " tokenizer" , " enwiki_tokenizer .vocab" );
721
732
// Default number of prompt tokens = input sequence length
722
733
const bool force_tokenize = parser.option (" force-tokenize" );
723
734
const long num_tokens = 1000 ;
@@ -760,7 +771,7 @@ int main(int argc, char** argv)
760
771
parser.option (" tokens-file" ).argument () :
761
772
generate_tokens_filename (enwiki_path, max_bytes);
762
773
763
- using ernie_transformer = ernie:: transformer_config<
774
+ using enwiki_transformer = transformer_config<
764
775
num_tokens, // vocab_size
765
776
num_layers, // number of layers
766
777
num_heads, // number of attention heads
@@ -945,9 +956,9 @@ int main(int argc, char** argv)
945
956
cout << " Created " << samples.size () << " training samples (100%)...\n " ;
946
957
947
958
// 5) Build and train the network
948
- using net_type = ernie_transformer ::network_type<true >;
959
+ using net_type = enwiki_transformer ::network_type<true >;
949
960
net_type net;
950
- cout << " Model architecture:\n " << ernie_transformer ::model_info::describe () << endl;
961
+ cout << " Model architecture:\n " << enwiki_transformer ::model_info::describe () << endl;
951
962
if (file_exists (model_file)) deserialize (model_file) >> net;
952
963
953
964
// Create trainer
@@ -958,7 +969,7 @@ int main(int argc, char** argv)
958
969
// For perfect memorization, we allow more epochs without improvement
959
970
trainer.set_iterations_without_progress_threshold (patience);
960
971
trainer.set_max_num_epochs (max_epochs); // More epochs for perfect memorization
961
- trainer.set_synchronization_file (" ernie_trainer .sync" , std::chrono::minutes (10 ));
972
+ trainer.set_synchronization_file (" enwiki_trainer .sync" , std::chrono::minutes (10 ));
962
973
trainer.be_quiet ();
963
974
964
975
// Custom training loop - trainer.train(samples, labels)
@@ -1027,27 +1038,29 @@ int main(int argc, char** argv)
1027
1038
net.clean ();
1028
1039
serialize (model_file) << net;
1029
1040
cout << " Model saved to " << model_file << " \n " ;
1030
- std::remove (" ernie_trainer .sync" );
1031
- std::remove (" ernie_trainer .sync_" );
1041
+ std::remove (" enwiki_trainer .sync" );
1042
+ std::remove (" enwiki_trainer .sync_" );
1032
1043
1033
1044
// Evaluate on training set
1034
- if (!g_terminate_flag.load ()) {
1035
- cout << " Evaluating model accuracy...\n " ;
1036
- using net_infer = ernie_transformer::network_type<false >;
1037
- net_infer g_infer = net;
1038
- auto predicted = g_infer (samples);
1039
- size_t correct = 0 ;
1040
- for (size_t i = 0 ; i < labels.size (); ++i)
1041
- if (predicted[i] == labels[i]) correct++;
1042
- double accuracy = (double )correct / labels.size ();
1043
- cout << " Training accuracy: " << (accuracy * 100.0 ) << " %\n " ;
1044
-
1045
- // We need perfect accuracy to reconstruct enwiki
1046
- if (accuracy < 0.9999 ) {
1047
- cout << " WARNING: Model accuracy is less than 99.99%. The model may not "
1048
- << " perfectly reconstruct the input text.\n " ;
1045
+ {
1046
+ if (!g_terminate_flag.load ()) {
1047
+ cout << " Evaluating model accuracy...\n " ;
1048
+ using net_infer = enwiki_transformer::network_type<false >;
1049
+ net_infer g_infer = net;
1050
+ auto predicted = g_infer (samples);
1051
+ size_t correct = 0 ;
1052
+ for (size_t i = 0 ; i < labels.size (); ++i)
1053
+ if (predicted[i] == labels[i]) correct++;
1054
+ double accuracy = (double )correct / labels.size ();
1055
+ cout << " Training accuracy: " << (accuracy * 100.0 ) << " %\n " ;
1056
+
1057
+ // We need perfect accuracy to reconstruct enwiki
1058
+ if (accuracy < 0.9999 ) {
1059
+ cout << " WARNING: Model accuracy is less than 99.99%. The model may not "
1060
+ << " perfectly reconstruct the input text.\n " ;
1061
+ }
1049
1062
}
1050
- }
1063
+ }
1051
1064
}
1052
1065
1053
1066
// ----------------------------------------------------------------------------------------
@@ -1058,7 +1071,7 @@ int main(int argc, char** argv)
1058
1071
cout << " === GENERATION MODE ===\n " ;
1059
1072
1060
1073
// 1) Load the model
1061
- using net_infer = ernie_transformer ::network_type<false >;
1074
+ using net_infer = enwiki_transformer ::network_type<false >;
1062
1075
net_infer net;
1063
1076
if (file_exists (model_file)) {
1064
1077
deserialize (model_file) >> net;
0 commit comments