@@ -112,7 +112,7 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTo
112
112
/// <param name="cacheSize">The size of the cache to use.</param>
113
113
/// <param name="normalizer">To normalize the text before tokenization</param>
114
114
/// <returns>The tokenizer</returns>
115
- public static Tokenizer CreateByModelName (
115
+ public static Tokenizer CreateTokenizerForModel (
116
116
string modelName ,
117
117
Stream vocabStream ,
118
118
IReadOnlyDictionary < string , int > ? extraSpecialTokens = null ,
@@ -124,7 +124,7 @@ public static Tokenizer CreateByModelName(
124
124
throw new ArgumentNullException ( nameof ( modelName ) ) ;
125
125
}
126
126
127
- ( Dictionary < string , int > SpecialTokens , Regex Regex ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
127
+ ( Dictionary < string , int > SpecialTokens , Regex Regex , string _ ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
128
128
129
129
if ( extraSpecialTokens is not null )
130
130
{
@@ -150,7 +150,7 @@ public static Tokenizer CreateByModelName(
150
150
/// <param name="normalizer">To normalize the text before tokenization</param>
151
151
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
152
152
/// <returns>The tokenizer</returns>
153
- public static async Task < Tokenizer > CreateByModelNameAsync (
153
+ public static async Task < Tokenizer > CreateTokenizerForModelAsync (
154
154
string modelName ,
155
155
Stream vocabStream ,
156
156
IReadOnlyDictionary < string , int > ? extraSpecialTokens = null ,
@@ -163,7 +163,7 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
163
163
throw new ArgumentNullException ( nameof ( modelName ) ) ;
164
164
}
165
165
166
- ( Dictionary < string , int > SpecialTokens , Regex Regex ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
166
+ ( Dictionary < string , int > SpecialTokens , Regex Regex , string _ ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
167
167
168
168
if ( extraSpecialTokens is not null )
169
169
{
@@ -738,31 +738,30 @@ private static ModelEncoding GetModelEncoding(string modelName)
738
738
return encoder ;
739
739
}
740
740
741
/// <summary>
/// Resolves the Tiktoken settings associated with a model: its built-in special tokens,
/// the pre-tokenization regex, and the URL of its vocabulary (mergeable ranks) file.
/// </summary>
/// <param name="modelName">Model name; mapped to a <see cref="ModelEncoding"/> via <c>GetModelEncoding</c>.</param>
/// <returns>
/// A tuple holding a freshly allocated special-token dictionary, the regex for the encoding,
/// and the vocabulary URL for the encoding.
/// </returns>
/// <exception cref="NotSupportedException">The resolved encoding is not one of the cases handled here.</exception>
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string Url) GetTiktokenConfigurations(string modelName)
{
    ModelEncoding encoding = GetModelEncoding(modelName);

    // Each arm allocates a new dictionary, so callers are free to add their own entries to it.
    return encoding switch
    {
        ModelEncoding.Cl100kBase => (
            new Dictionary<string, int>
            {
                { EndOfText, 100257 },
                { FimPrefix, 100258 },
                { FimMiddle, 100259 },
                { FimSuffix, 100260 },
                { EndOfPrompt, 100276 }
            },
            Cl100kBaseRegex(),
            Cl100kBaseVocabUrl),

        ModelEncoding.P50kBase => (
            new Dictionary<string, int> { { EndOfText, 50256 } },
            P50kBaseRegex(),
            P50RanksUrl),

        ModelEncoding.P50kEdit => (
            new Dictionary<string, int>
            {
                { EndOfText, 50256 },
                { FimPrefix, 50281 },
                { FimMiddle, 50282 },
                { FimSuffix, 50283 }
            },
            P50kBaseRegex(),
            P50RanksUrl),

        // R50kBase and GPT2 share the P50k regex; only the vocabulary URL differs.
        ModelEncoding.R50kBase => (
            new Dictionary<string, int> { { EndOfText, 50256 } },
            P50kBaseRegex(),
            R50RanksUrl),

        ModelEncoding.GPT2 => (
            new Dictionary<string, int> { { EndOfText, 50256 } },
            P50kBaseRegex(),
            GPT2Url),

        _ => throw new NotSupportedException($"The model '{modelName}' is not supported."),
    };
}
@@ -775,22 +774,64 @@ internal static (Dictionary<string, int> SpecialTokens, Regex Regex) GetTiktoken
775
774
/// <param name="normalizer">To normalize the text before tokenization</param>
776
775
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
777
776
/// <returns>The tokenizer</returns>
778
public static Task<Tokenizer> CreateTokenizerForModelAsync(
    string modelName,
    IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
    Normalizer? normalizer = null,
    CancellationToken cancellationToken = default)
{
    // This method is not 'async', so anything thrown synchronously (by the encoding lookup
    // or before the callee hands back its task) would escape to the caller directly.
    // Catch it and surface it as a faulted task instead.
    try
    {
        ModelEncoding encoding = GetModelEncoding(modelName);
        Task<Tokenizer> pending = CreateByEncoderNameAsync(encoding, extraSpecialTokens, normalizer, cancellationToken);
        return pending;
    }
    catch (Exception failure)
    {
        return Task.FromException<Tokenizer>(failure);
    }
}
793
792
793
/// <summary>
/// Synchronously builds a Tiktoken tokenizer for the named model, fetching the model's
/// vocabulary file when it is not already in the in-memory cache.
/// </summary>
/// <param name="modelName">Model name used to look up the encoding configuration.</param>
/// <param name="extraSpecialTokens">Optional special tokens added on top of the model's built-in ones.</param>
/// <param name="normalizer">Optional normalizer applied to the text before tokenization.</param>
/// <returns>The tokenizer</returns>
/// <exception cref="ArgumentNullException"><paramref name="modelName"/> is null or empty.</exception>
public static Tokenizer CreateTokenizerForModel(
    string modelName,
    IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
    Normalizer? normalizer = null)
{
    if (string.IsNullOrEmpty(modelName))
    {
        throw new ArgumentNullException(nameof(modelName));
    }

    (Dictionary<string, int> specialTokens, Regex regex, string vocabUrl) = GetTiktokenConfigurations(modelName);

    if (extraSpecialTokens is not null)
    {
        // Dictionary.Add is used on purpose: a key that collides with an existing entry throws.
        foreach (KeyValuePair<string, int> extra in extraSpecialTokens)
        {
            specialTokens.Add(extra.Key, extra.Value);
        }
    }

    bool cached = _tiktokenCache.TryGetValue(
        vocabUrl,
        out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache);

    if (!cached)
    {
        // Cache miss: fetch and parse the vocabulary, then publish it under its URL.
        using Stream stream = Helpers.GetStream(_httpClient, vocabUrl);
        cache = LoadTikTokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();

        // TryAdd leaves any entry already stored under this URL untouched.
        _tiktokenCache.TryAdd(vocabUrl, cache);
    }

    var model = new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens, LruCache<int[]>.DefaultCacheSize);
    var preTokenizer = new TikTokenPreTokenizer(regex, specialTokens);

    return new Tokenizer(model, preTokenizer, normalizer);
}
834
+
794
835
// Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
795
836
796
837
private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" ;
@@ -818,15 +859,13 @@ public static Task<Tokenizer> CreateByModelNameAsync(
818
859
/// <summary>
819
860
/// Create tokenizer based on encoder name and extra special tokens
820
861
/// </summary>
821
- /// <param name="modelName">Model name</param>
822
862
/// <param name="modelEncoding">Encoder label</param>
823
863
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
824
864
/// <param name="normalizer">To normalize the text before tokenization</param>
825
865
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
826
866
/// <returns>The tokenizer</returns>
827
867
/// <exception cref="NotSupportedException">Throws if the encoder is not supported</exception>
828
868
private static Task < Tokenizer > CreateByEncoderNameAsync (
829
- string modelName ,
830
869
ModelEncoding modelEncoding ,
831
870
IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
832
871
Normalizer ? normalizer ,
@@ -857,8 +896,7 @@ private static Task<Tokenizer> CreateByEncoderNameAsync(
857
896
return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
858
897
859
898
default :
860
- Debug . Assert ( false , $ "Unexpected encoder [{ modelEncoding } ]") ;
861
- throw new NotSupportedException ( $ "The model '{ modelName } ' is not supported.") ;
899
+ throw new NotSupportedException ( $ "The encoder '{ modelEncoding } ' is not supported.") ;
862
900
}
863
901
}
864
902
@@ -894,7 +932,7 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
894
932
{
895
933
using ( Stream stream = await Helpers . GetStreamAsync ( _httpClient , mergeableRanksFileUrl , cancellationToken ) . ConfigureAwait ( false ) )
896
934
{
897
- cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
935
+ cache = await LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
898
936
}
899
937
900
938
_tiktokenCache . TryAdd ( mergeableRanksFileUrl , cache ) ;
0 commit comments