Skip to content

Commit f9332d5

Browse files
simplify creation of tuned model
1 parent e8acad1 commit f9332d5

File tree

5 files changed

+127
-40
lines changed

5 files changed

+127
-40
lines changed

README.md

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -184,40 +184,28 @@ var projectId = "your_google_project_id"; // the ID of a project, not its name.
184184
var accessToken = "your_access_token"; // use `gcloud auth application-default print-access-token` to get it.
185185
var model = new GenerativeModel(apiKey: null, model: Model.Gemini10Pro001)
186186
{
187-
AccessToken = accessToken,
188-
ProjectId = projectId
187+
AccessToken = accessToken, ProjectId = projectId
189188
};
190-
var request = new CreateTunedModelRequest()
191-
{
192-
BaseModel = $"models/{Model.Gemini10Pro001}",
193-
DisplayName = "Autogenerated Test model",
194-
TuningTask = new()
195-
{
196-
Hyperparameters = new() { BatchSize = 2, LearningRate = 0.001f, EpochCount = 3 },
197-
TrainingData = new()
198-
{
199-
Examples = new()
200-
{
201-
Examples = new()
202-
{
203-
new TuningExample() { TextInput = "1", Output = "2" },
204-
new TuningExample() { TextInput = "3", Output = "4" },
205-
new TuningExample() { TextInput = "-3", Output = "-2" },
206-
new TuningExample() { TextInput = "twenty two", Output = "twenty three" },
207-
new TuningExample() { TextInput = "two hundred", Output = "two hundred one" },
208-
new TuningExample() { TextInput = "ninety nine", Output = "one hundred" },
209-
new TuningExample() { TextInput = "8", Output = "9" },
210-
new TuningExample() { TextInput = "-98", Output = "-97" },
211-
new TuningExample() { TextInput = "1,000", Output = "1,001" },
212-
new TuningExample() { TextInput = "thirteen", Output = "fourteen" },
213-
new TuningExample() { TextInput = "seven", Output = "eight" },
214-
}
215-
}
216-
}
217-
}
189+
var parameters = new HyperParameters() { BatchSize = 2, LearningRate = 0.001f, EpochCount = 3 };
190+
var dataset = new List<TuningExample>
191+
{
192+
new() { TextInput = "1", Output = "2" },
193+
new() { TextInput = "3", Output = "4" },
194+
new() { TextInput = "-3", Output = "-2" },
195+
new() { TextInput = "twenty two", Output = "twenty three" },
196+
new() { TextInput = "two hundred", Output = "two hundred one" },
197+
new() { TextInput = "ninety nine", Output = "one hundred" },
198+
new() { TextInput = "8", Output = "9" },
199+
new() { TextInput = "-98", Output = "-97" },
200+
new() { TextInput = "1,000", Output = "1,001" },
201+
new() { TextInput = "thirteen", Output = "fourteen" },
202+
new() { TextInput = "seven", Output = "eight" },
218203
};
204+
var request = new CreateTunedModelRequest(Model.Gemini10Pro001,
205+
"Simply autogenerated Test model",
206+
dataset,
207+
parameters);
219208

220-
// Act
221209
var response = await model.CreateTunedModel(request);
222210
Console.WriteLine($"Name: {response.Name}");
223211
Console.WriteLine($"Model: {response.Metadata.TunedModel} (Steps: {response.Metadata.TotalSteps})");

src/Mscc.GenerativeAI/GenerativeModel.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ private async Task<List<ModelResponse>> ListTunedModels(int? pageSize = null,
331331
/// <exception cref="NotSupportedException"></exception>
332332
public async Task<CreateTunedModelResponse> CreateTunedModel(CreateTunedModelRequest request)
333333
{
334-
if (!(_model is (string)GenerativeAI.Model.BisonText001 ||
335-
_model is (string)GenerativeAI.Model.Gemini10Pro001))
334+
if (!(_model.Equals($"models/{GenerativeAI.Model.BisonText001}", StringComparison.InvariantCultureIgnoreCase) ||
335+
_model.Equals($"models/{GenerativeAI.Model.Gemini10Pro001}", StringComparison.InvariantCultureIgnoreCase)))
336336
{
337337
throw new NotSupportedException();
338338
}
@@ -664,7 +664,7 @@ public IAsyncEnumerable<GenerateContentResponse> GenerateContentStream(List<IPar
664664
public async Task<GenerateAnswerResponse> GenerateAnswer(GenerateAnswerRequest? request)
665665
{
666666
if (request == null) throw new ArgumentNullException(nameof(request));
667-
if (_model != (string)GenerativeAI.Model.AttributedQuestionAnswering)
667+
if (!_model.Equals($"models/{GenerativeAI.Model.AttributedQuestionAnswering}", StringComparison.InvariantCultureIgnoreCase))
668668
{
669669
throw new NotSupportedException();
670670
}
@@ -711,7 +711,7 @@ public async Task<GenerateAnswerResponse> GenerateAnswer(GenerateAnswerRequest?
711711
public async Task<EmbedContentResponse> EmbedContent(EmbedContentRequest request, TaskType? taskType = null, string? title = null)
712712
{
713713
if (request == null) throw new ArgumentNullException(nameof(request));
714-
if (_model != (string)GenerativeAI.Model.Embedding)
714+
if (!_model.Equals($"models/{GenerativeAI.Model.Embedding}", StringComparison.InvariantCultureIgnoreCase))
715715
{
716716
throw new NotSupportedException();
717717
}
@@ -736,7 +736,7 @@ public async Task<EmbedContentResponse> EmbedContent(EmbedContentRequest request
736736
public async Task<EmbedContentResponse> EmbedContent(string? prompt, TaskType? taskType = null, string? title = null)
737737
{
738738
if (prompt == null) throw new ArgumentNullException(nameof(prompt));
739-
if (_model != (string)GenerativeAI.Model.Embedding)
739+
if (!_model.Equals($"models/{GenerativeAI.Model.Embedding}", StringComparison.InvariantCultureIgnoreCase))
740740
{
741741
throw new NotSupportedException();
742742
}
@@ -759,7 +759,7 @@ public async Task<EmbedContentResponse> EmbedContent(string? prompt, TaskType? t
759759
public async Task<EmbedContentResponse> BatchEmbedContent(List<EmbedContentRequest> requests)
760760
{
761761
if (requests == null) throw new ArgumentNullException(nameof(requests));
762-
if (_model != (string)GenerativeAI.Model.Embedding)
762+
if (!_model.Equals($"models/{GenerativeAI.Model.Embedding}", StringComparison.InvariantCultureIgnoreCase))
763763
{
764764
throw new NotSupportedException();
765765
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,64 @@
1+
#if NET472_OR_GREATER || NETSTANDARD2_0
2+
using System;
3+
using System.Collections.Generic;
4+
#endif
5+
16
namespace Mscc.GenerativeAI
27
{
8+
/// <summary>
9+
/// Request to create a tuned model.
10+
/// </summary>
311
public class CreateTunedModelRequest
412
{
13+
/// <summary>
14+
/// The name to display for this model in user interfaces. The display name must be up to 40 characters including spaces.
15+
/// </summary>
516
public string DisplayName { get; set; }
17+
/// <summary>
18+
/// The name of the Model to tune. Example: models/text-bison-001
19+
/// </summary>
620
public string BaseModel { get; set; }
21+
/// <summary>
22+
/// Tuning tasks that create tuned models.
23+
/// </summary>
724
public TuningTask TuningTask { get; set; }
25+
26+
/// <summary>
27+
/// Constructor.
28+
/// </summary>
29+
public CreateTunedModelRequest()
30+
{
31+
TuningTask = new TuningTask
32+
{
33+
TrainingData = new()
34+
{
35+
Examples = new()
36+
{
37+
Examples = new() { }
38+
}
39+
}
40+
};
41+
}
42+
43+
/// <summary>
44+
/// Creates a request for a tuned model.
45+
/// </summary>
46+
/// <param name="model">Model to use.</param>
47+
/// <param name="name">Name of the tuned model.</param>
48+
/// <param name="dataset">Dataset for training or validation.</param>
49+
/// <param name="parameters">Immutable. Hyperparameters controlling the tuning process. If not provided, default values will be used.</param>
50+
/// <exception cref="ArgumentNullException"></exception>
51+
public CreateTunedModelRequest(string model, string name,
52+
List<TuningExample>? dataset = null,
53+
HyperParameters? parameters = null) : this()
54+
{
55+
if (model is null) throw new ArgumentNullException(nameof(model));
56+
if (name is null) throw new ArgumentNullException(nameof(name));
57+
58+
BaseModel = model.SanitizeModelName();
59+
DisplayName = name.Trim();
60+
TuningTask.Hyperparameters = parameters;
61+
TuningTask.TrainingData.Examples.Examples = dataset ?? new();
62+
}
863
}
964
}

src/Mscc.GenerativeAI/Types/CreateTunedModelResponse.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
namespace Mscc.GenerativeAI
44
{
5+
/// <summary>
6+
/// Response of a newly created tuned model.
7+
/// </summary>
58
public class CreateTunedModelResponse
69
{
710
public string Name { get; set; }
@@ -13,6 +16,9 @@ public class CreateTunedModelMetadata
1316
{
1417
public string Type { get; set; }
1518
public int TotalSteps { get; set; }
19+
/// <summary>
20+
/// A fine-tuned model created using ModelService.CreateTunedModel.
21+
/// </summary>
1622
public string TunedModel { get; set; }
1723
}
1824
}

tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,8 +1037,7 @@ public async void Create_Tuned_Model()
10371037
// Arrange
10381038
var model = new GenerativeModel(apiKey: null, model: Model.Gemini10Pro001)
10391039
{
1040-
AccessToken = fixture.AccessToken,
1041-
ProjectId = fixture.ProjectId
1040+
AccessToken = fixture.AccessToken, ProjectId = fixture.ProjectId
10421041
};
10431042
var request = new CreateTunedModelRequest()
10441043
{
@@ -1069,7 +1068,7 @@ public async void Create_Tuned_Model()
10691068
}
10701069
}
10711070
};
1072-
1071+
10731072
// Act
10741073
var response = await model.CreateTunedModel(request);
10751074

@@ -1081,6 +1080,45 @@ public async void Create_Tuned_Model()
10811080
output.WriteLine($"Model: {response.Metadata.TunedModel} (Steps: {response.Metadata.TotalSteps})");
10821081
}
10831082

1083+
[Fact]
1084+
public async void Create_Tuned_Model_Simply()
1085+
{
1086+
// Arrange
1087+
var model = new GenerativeModel(apiKey: null, model: Model.Gemini10Pro001)
1088+
{
1089+
AccessToken = fixture.AccessToken, ProjectId = fixture.ProjectId
1090+
};
1091+
var parameters = new HyperParameters() { BatchSize = 2, LearningRate = 0.001f, EpochCount = 3 };
1092+
var dataset = new List<TuningExample>
1093+
{
1094+
new() { TextInput = "1", Output = "2" },
1095+
new() { TextInput = "3", Output = "4" },
1096+
new() { TextInput = "-3", Output = "-2" },
1097+
new() { TextInput = "twenty two", Output = "twenty three" },
1098+
new() { TextInput = "two hundred", Output = "two hundred one" },
1099+
new() { TextInput = "ninety nine", Output = "one hundred" },
1100+
new() { TextInput = "8", Output = "9" },
1101+
new() { TextInput = "-98", Output = "-97" },
1102+
new() { TextInput = "1,000", Output = "1,001" },
1103+
new() { TextInput = "thirteen", Output = "fourteen" },
1104+
new() { TextInput = "seven", Output = "eight" },
1105+
};
1106+
var request = new CreateTunedModelRequest(Model.Gemini10Pro001,
1107+
"Simply autogenerated Test model",
1108+
dataset,
1109+
parameters);
1110+
1111+
// Act
1112+
var response = await model.CreateTunedModel(request);
1113+
1114+
// Assert
1115+
response.Should().NotBeNull();
1116+
response.Name.Should().NotBeNull();
1117+
response.Metadata.Should().NotBeNull();
1118+
output.WriteLine($"Name: {response.Name}");
1119+
output.WriteLine($"Model: {response.Metadata.TunedModel} (Steps: {response.Metadata.TotalSteps})");
1120+
}
1121+
10841122
[Fact]
10851123
public async void Delete_Tuned_Model()
10861124
{

0 commit comments

Comments
 (0)