Skip to content

Commit cea9d90

Browse files
tarekgh and ericstj authored
Add Tiktoken Synchronous Creation Using Model Name (#7080)
* Add Tiktoken Synchronous Creation Using Model Name * Add RemoteExecutor to Tokenizers tests * Address the feedback * Add tests --------- Co-authored-by: Eric StJohn <ericstj@microsoft.com>
1 parent c69c4a0 commit cea9d90

File tree

8 files changed

+114
-30
lines changed

8 files changed

+114
-30
lines changed

eng/Version.Details.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
<Uri>https://github.com/dotnet/arcade</Uri>
2424
<Sha>812d978c303174dc1aa305d7359e79053d7d4971</Sha>
2525
</Dependency>
26+
<!-- Stay on package 8.0 until we stop testing for net6.0
27+
<Dependency Name="Microsoft.DotNet.RemoteExecutor" Version="9.0.0-beta.24165.3">
28+
<Uri>https://github.com/dotnet/arcade</Uri>
29+
<Sha>812d978c303174dc1aa305d7359e79053d7d4971</Sha>
30+
</Dependency> -->
2631
<Dependency Name="Microsoft.DotNet.SwaggerGenerator.MSBuild" Version="9.0.0-beta.24165.3">
2732
<Uri>https://github.com/dotnet/arcade</Uri>
2833
<Sha>812d978c303174dc1aa305d7359e79053d7d4971</Sha>

eng/Versions.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
<DotNetRuntime80Version>8.0.1</DotNetRuntime80Version>
8080
<FluentAssertionVersion>5.10.2</FluentAssertionVersion>
8181
<MicrosoftCodeAnalysisTestingVersion>1.1.2-beta1.23431.1</MicrosoftCodeAnalysisTestingVersion>
82+
<MicrosoftDotNetRemoteExecutorVersion>8.0.0-beta.24165.4</MicrosoftDotNetRemoteExecutorVersion>
8283
<MicrosoftDotNetXUnitExtensionsVersion>9.0.0-beta.24165.3</MicrosoftDotNetXUnitExtensionsVersion>
8384
<MicrosoftExtensionsDependencyModelVersion>2.1.0</MicrosoftExtensionsDependencyModelVersion>
8485
<MicrosoftExtensionsTestVersion>3.0.1</MicrosoftExtensionsTestVersion>

src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTo
112112
/// <param name="cacheSize">The size of the cache to use.</param>
113113
/// <param name="normalizer">To normalize the text before tokenization</param>
114114
/// <returns>The tokenizer</returns>
115-
public static Tokenizer CreateByModelName(
115+
public static Tokenizer CreateTokenizerForModel(
116116
string modelName,
117117
Stream vocabStream,
118118
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
@@ -124,7 +124,7 @@ public static Tokenizer CreateByModelName(
124124
throw new ArgumentNullException(nameof(modelName));
125125
}
126126

127-
(Dictionary<string, int> SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
127+
(Dictionary<string, int> SpecialTokens, Regex Regex, string _) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
128128

129129
if (extraSpecialTokens is not null)
130130
{
@@ -150,7 +150,7 @@ public static Tokenizer CreateByModelName(
150150
/// <param name="normalizer">To normalize the text before tokenization</param>
151151
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
152152
/// <returns>The tokenizer</returns>
153-
public static async Task<Tokenizer> CreateByModelNameAsync(
153+
public static async Task<Tokenizer> CreateTokenizerForModelAsync(
154154
string modelName,
155155
Stream vocabStream,
156156
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
@@ -163,7 +163,7 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
163163
throw new ArgumentNullException(nameof(modelName));
164164
}
165165

166-
(Dictionary<string, int> SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
166+
(Dictionary<string, int> SpecialTokens, Regex Regex, string _) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
167167

168168
if (extraSpecialTokens is not null)
169169
{
@@ -738,31 +738,30 @@ private static ModelEncoding GetModelEncoding(string modelName)
738738
return encoder;
739739
}
740740

741-
internal static (Dictionary<string, int> SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName)
741+
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string Url) GetTiktokenConfigurations(string modelName)
742742
{
743743
ModelEncoding modelEncoding = GetModelEncoding(modelName);
744744

745745
switch (modelEncoding)
746746
{
747747
case ModelEncoding.Cl100kBase:
748748
return (new Dictionary<string, int>
749-
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex());
749+
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex(), Cl100kBaseVocabUrl);
750750

751751
case ModelEncoding.P50kBase:
752-
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex());
752+
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex(), P50RanksUrl);
753753

754754
case ModelEncoding.P50kEdit:
755755
return (new Dictionary<string, int>
756-
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex());
756+
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex(), P50RanksUrl);
757757

758758
case ModelEncoding.R50kBase:
759-
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex());
759+
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex(), R50RanksUrl);
760760

761761
case ModelEncoding.GPT2:
762-
return (new Dictionary<string, int> { { EndOfText, 50256 }, }, P50kBaseRegex());
762+
return (new Dictionary<string, int> { { EndOfText, 50256 }, }, P50kBaseRegex(), GPT2Url);
763763

764764
default:
765-
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
766765
throw new NotSupportedException($"The model '{modelName}' is not supported.");
767766
}
768767
}
@@ -775,22 +774,64 @@ internal static (Dictionary<string, int> SpecialTokens, Regex Regex) GetTiktoken
775774
/// <param name="normalizer">To normalize the text before tokenization</param>
776775
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
777776
/// <returns>The tokenizer</returns>
778-
public static Task<Tokenizer> CreateByModelNameAsync(
777+
public static Task<Tokenizer> CreateTokenizerForModelAsync(
779778
string modelName,
780779
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
781780
Normalizer? normalizer = null,
782781
CancellationToken cancellationToken = default)
783782
{
784783
try
785784
{
786-
return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken);
785+
return CreateByEncoderNameAsync(GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken);
787786
}
788787
catch (Exception ex)
789788
{
790789
return Task.FromException<Tokenizer>(ex);
791790
}
792791
}
793792

793+
/// <summary>
794+
/// Create tokenizer based on model name
795+
/// </summary>
796+
/// <param name="modelName">Model name</param>
797+
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
798+
/// <param name="normalizer">To normalize the text before tokenization</param>
799+
/// <returns>The tokenizer</returns>
800+
public static Tokenizer CreateTokenizerForModel(
801+
string modelName,
802+
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
803+
Normalizer? normalizer = null)
804+
{
805+
if (string.IsNullOrEmpty(modelName))
806+
{
807+
throw new ArgumentNullException(nameof(modelName));
808+
}
809+
810+
(Dictionary<string, int> SpecialTokens, Regex Regex, string Url) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
811+
812+
if (extraSpecialTokens is not null)
813+
{
814+
foreach (var extraSpecialToken in extraSpecialTokens)
815+
{
816+
tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
817+
}
818+
}
819+
820+
if (!_tiktokenCache.TryGetValue(tiktokenConfiguration.Url,
821+
out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
822+
{
823+
using Stream stream = Helpers.GetStream(_httpClient, tiktokenConfiguration.Url);
824+
cache = LoadTikTokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();
825+
826+
_tiktokenCache.TryAdd(tiktokenConfiguration.Url, cache);
827+
}
828+
829+
return new Tokenizer(
830+
new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize),
831+
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
832+
normalizer);
833+
}
834+
794835
// Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
795836

796837
private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
@@ -818,15 +859,13 @@ public static Task<Tokenizer> CreateByModelNameAsync(
818859
/// <summary>
819860
/// Create tokenizer based on encoder name and extra special tokens
820861
/// </summary>
821-
/// <param name="modelName">Model name</param>
822862
/// <param name="modelEncoding">Encoder label</param>
823863
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
824864
/// <param name="normalizer">To normalize the text before tokenization</param>
825865
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
826866
/// <returns>The tokenizer</returns>
827867
/// <exception cref="NotSupportedException">Throws if the model name is not supported</exception>
828868
private static Task<Tokenizer> CreateByEncoderNameAsync(
829-
string modelName,
830869
ModelEncoding modelEncoding,
831870
IReadOnlyDictionary<string, int>? extraSpecialTokens,
832871
Normalizer? normalizer,
@@ -857,8 +896,7 @@ private static Task<Tokenizer> CreateByEncoderNameAsync(
857896
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
858897

859898
default:
860-
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
861-
throw new NotSupportedException($"The model '{modelName}' is not supported.");
899+
throw new NotSupportedException($"The encoder '{modelEncoding}' is not supported.");
862900
}
863901
}
864902

@@ -894,7 +932,7 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
894932
{
895933
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
896934
{
897-
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
935+
cache = await LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
898936
}
899937

900938
_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);

src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@ internal static class Helpers
1818
public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
1919
reader.ReadLineAsync(cancellationToken);
2020

21-
public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
21+
public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken = default) =>
2222
client.GetStreamAsync(url, cancellationToken);
2323

24+
public static Stream GetStream(HttpClient client, string url)
25+
{
26+
HttpResponseMessage response = client.Send(new HttpRequestMessage(HttpMethod.Get, url), HttpCompletionOption.ResponseHeadersRead);
27+
response.EnsureSuccessStatusCode();
28+
return response.Content.ReadAsStream();
29+
}
30+
2431
public static byte[] FromBase64String(string base64String, int offset, int length)
2532
{
2633
if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))

src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@ public static ValueTask<string> ReadLineAsync(StreamReader reader, CancellationT
1818
return new ValueTask<string>(reader.ReadLineAsync());
1919
}
2020

21-
public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
21+
public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken = default)
2222
{
2323
HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
2424
response.EnsureSuccessStatusCode();
2525
return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
2626
}
2727

28+
public static Stream GetStream(HttpClient client, string url) => client.GetStreamAsync(url).GetAwaiter().GetResult();
29+
2830
public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));
2931

3032
// Not support signed number

test/Microsoft.ML.TestFramework/Microsoft.ML.TestFramework.csproj

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
</ItemGroup>
3535

3636
<ItemGroup Condition="'$(TargetFramework)' != 'net462'">
37-
<!-- This reference will be updated to use DARC in a subsequent PR so we can leave the version here as is -->
38-
<PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="7.0.0-beta.21456.1" />
37+
<PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotNetRemoteExecutorVersion)" />
3938
</ItemGroup>
4039

4140
</Project>

test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
</ItemGroup>
3737

3838
<ItemGroup>
39+
<PackageReference Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotNetRemoteExecutorVersion)" />
3940
<PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
4041
</ItemGroup>
4142

test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.DotNet.RemoteExecutor;
56
using Microsoft.ML.Tokenizers;
67
using System;
78
using System.Collections.Generic;
@@ -25,11 +26,11 @@ public class TiktokenTests
2526
{ IMEnd, 100265},
2627
};
2728

28-
public static Tokenizer GPT4 { get; } = Tiktoken.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult();
29-
public static Tokenizer GPT2 { get; } = Tiktoken.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult();
30-
public static Tokenizer P50kBase { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult();
31-
public static Tokenizer R50kBase { get; } = Tiktoken.CreateByModelNameAsync("ada").GetAwaiter().GetResult();
32-
public static Tokenizer P50kEdit { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult();
29+
public static Tokenizer GPT4 { get; } = Tiktoken.CreateTokenizerForModelAsync("gpt-4", _specialTokens).GetAwaiter().GetResult();
30+
public static Tokenizer GPT2 { get; } = Tiktoken.CreateTokenizerForModelAsync("gpt2").GetAwaiter().GetResult();
31+
public static Tokenizer P50kBase { get; } = Tiktoken.CreateTokenizerForModelAsync("text-davinci-003").GetAwaiter().GetResult();
32+
public static Tokenizer R50kBase { get; } = Tiktoken.CreateTokenizerForModelAsync("ada").GetAwaiter().GetResult();
33+
public static Tokenizer P50kEdit { get; } = Tiktoken.CreateTokenizerForModelAsync("text-davinci-edit-001").GetAwaiter().GetResult();
3334

3435
[Fact]
3536
public async void TestTokenizerCreation()
@@ -64,15 +65,18 @@ public async void TestTokenizerCreation()
6465

6566
using (Stream stream = File.OpenRead(tokenizerDataFileName))
6667
{
67-
tokenizer = Tiktoken.CreateByModelName("gpt-4", stream);
68+
tokenizer = Tiktoken.CreateTokenizerForModel("gpt-4", stream);
6869
}
6970
TestGPT4TokenizationEncoding(tokenizer);
7071

7172
using (Stream stream = File.OpenRead(tokenizerDataFileName))
7273
{
73-
tokenizer = await Tiktoken.CreateByModelNameAsync("gpt-3.5-turbo", stream);
74+
tokenizer = await Tiktoken.CreateTokenizerForModelAsync("gpt-3.5-turbo", stream);
7475
}
7576
TestGPT4TokenizationEncoding(tokenizer);
77+
78+
tokenizer = Tiktoken.CreateTokenizerForModel("gpt-4");
79+
TestGPT4TokenizationEncoding(tokenizer);
7680
}
7781
finally
7882
{
@@ -298,11 +302,38 @@ public void TestEncodeR50kBase()
298302
[InlineData("gpt2")]
299303
public async void TestAllSupportedModelNames(string modelName)
300304
{
301-
Tokenizer tokenizer = await Tiktoken.CreateByModelNameAsync(modelName);
305+
Tokenizer tokenizer = Tiktoken.CreateTokenizerForModel(modelName);
306+
Assert.NotNull(tokenizer.Model);
307+
Assert.NotNull(tokenizer.PreTokenizer);
308+
309+
tokenizer = await Tiktoken.CreateTokenizerForModelAsync(modelName);
302310
Assert.NotNull(tokenizer.Model);
303311
Assert.NotNull(tokenizer.PreTokenizer);
304312
}
305313

314+
[InlineData("gpt-4")]
315+
[InlineData("text-davinci-003")]
316+
[InlineData("text-curie-001")]
317+
[InlineData("text-davinci-edit-001")]
318+
[ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
319+
public void TestCreationUsingModel(string modelName)
320+
{
321+
// Execute remotely to ensure no caching is used.
322+
RemoteExecutor.Invoke(static async (name) =>
323+
{
324+
Tokenizer tokenizer = await Tiktoken.CreateTokenizerForModelAsync(name);
325+
Assert.NotNull(tokenizer.Model);
326+
Assert.NotNull(tokenizer.PreTokenizer);
327+
}, modelName).Dispose();
328+
329+
RemoteExecutor.Invoke(static (name) =>
330+
{
331+
Tokenizer tokenizer = Tiktoken.CreateTokenizerForModel(name);
332+
Assert.NotNull(tokenizer.Model);
333+
Assert.NotNull(tokenizer.PreTokenizer);
334+
}, modelName).Dispose();
335+
}
336+
306337
// Test running copy the test data files to the output folder but sometimes the file content is mutated replacing '\n' with '\r\n'.
307338
// This method reads the file and removes the extra inserted '\r' characters. Having '\r' in the file content will cause the tests to fail.
308339
private string ReadAndSanitizeFile(string path)

0 commit comments

Comments (0)