Skip to content

Commit 01fe8f4

Browse files
use EnvVars for default values; refactor Name; extend tests
1 parent e0fd8bb commit 01fe8f4

13 files changed

+182
-54
lines changed

src/Mscc.GenerativeAI.Google/Mscc.GenerativeAI.Google.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
</ItemGroup>
6262

6363
<ItemGroup>
64-
<Content Include="client_secrets.json">
64+
<Content Include="client_secret.json">
6565
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
6666
</Content>
6767
<Content Include="key.p12">

src/Mscc.GenerativeAI/Constants/Model.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,5 @@ public static class Model
3838
public const string GeckoEmbedding = "embedding-gecko-001";
3939
public const string Embedding = "embedding-001";
4040
public const string AttributedQuestionAnswering = "aqa";
41-
42-
public static string Sanitize(this string value)
43-
{
44-
if (value.StartsWith("model", StringComparison.InvariantCultureIgnoreCase))
45-
{
46-
var parts = value.Split(new char[] { '/' }, System.StringSplitOptions.RemoveEmptyEntries);
47-
value = parts.Last();
48-
}
49-
return value.ToLower();
50-
}
5141
}
5242
}

src/Mscc.GenerativeAI/GenerativeModel.cs

Lines changed: 116 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Threading;
1010
using System.Threading.Tasks;
1111
#endif
12+
using System.Diagnostics;
1213
using System.Net;
1314
using System.Net.Http.Headers;
1415
using System.Runtime.CompilerServices;
@@ -29,11 +30,12 @@ public class GenerativeModel
2930
private readonly bool _useHeaderApiKey;
3031
private readonly bool _useHeaderProjectId;
3132
private readonly string _model;
32-
private readonly string _apiKey;
33-
private readonly string _projectId;
34-
private readonly string _region;
33+
private readonly string? _apiKey;
34+
private readonly string? _projectId;
35+
private readonly string? _region;
3536
private readonly string _publisher = "google";
3637
private readonly JsonSerializerOptions _options;
38+
private string? _accessToken;
3739
private List<SafetySetting>? _safetySettings;
3840
private GenerationConfig? _generationConfig;
3941
private List<Tool>? _tools;
@@ -109,16 +111,20 @@ private string Method
109111
}
110112
}
111113

112-
// Todo: Remove after ADC has been added.
113-
private string _accessToken;
114+
/// <summary>
115+
/// Returns the name of the model.
116+
/// </summary>
117+
/// <returns>Name of the model.</returns>
118+
public string Name => _model;
114119

115-
public string AccessToken
120+
public string? AccessToken
116121
{
117122
get => _accessToken;
118123
set
119124
{
120125
_accessToken = value;
121-
Client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken);
126+
if (value != null)
127+
Client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken);
122128
}
123129
}
124130

@@ -128,10 +134,28 @@ public string AccessToken
128134
public GenerativeModel()
129135
{
130136
_options = DefaultJsonSerializerOptions();
131-
// GOOGLE_APPLICATION_CREDENTIALS
132-
// Linux, macOS: $HOME /.config / gcloud / application_default_credentials.json
133-
// Windows: % APPDATA %\gcloud\application_default_credentials.json
134-
//var credentials = GoogleCredential.FromFile(Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "gcloud", "application_default_credentials.json"))
137+
_model = Environment.GetEnvironmentVariable("GOOGLE_AI_MODEL") ??
138+
Model.Gemini10Pro;
139+
_apiKey = Environment.GetEnvironmentVariable("GOOGLE_API_KEY");
140+
_projectId = Environment.GetEnvironmentVariable("GOOGLE_PROJECT_ID");
141+
_region = Environment.GetEnvironmentVariable("GOOGLE_REGION");
142+
AccessToken = Environment.GetEnvironmentVariable("GOOGLE_ACCESS_TOKEN") ??
143+
GetAccessTokenFromAdc();
144+
145+
var credentialsFile =
146+
Environment.GetEnvironmentVariable("GOOGLE_APPLICATION_CREDENTIALS") ??
147+
Environment.GetEnvironmentVariable("GOOGLE_WEB_CREDENTIALS") ??
148+
Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "gcloud",
149+
"application_default_credentials.json");
150+
if (File.Exists(credentialsFile))
151+
{
152+
using (var stream = new FileStream(credentialsFile, FileMode.Open, FileAccess.Read))
153+
{
154+
var json = JsonSerializer.DeserializeAsync<JsonElement>(stream, _options).Result;
155+
_projectId ??= json.GetValue("quota_project_id") ??
156+
json.GetValue("project_id");
157+
}
158+
} //var credentials = GoogleCredential.FromFile()
135159
}
136160

137161
// Todo: Add parameters for GenerationConfig, SafetySettings, Transport? and Tools
@@ -142,15 +166,15 @@ public GenerativeModel()
142166
/// <param name="model">Model to use (default: "gemini-pro")</param>
143167
/// <param name="generationConfig"></param>
144168
/// <param name="safetySettings"></param>
145-
public GenerativeModel(string apiKey = "",
146-
string model = Model.GeminiPro,
169+
public GenerativeModel(string? apiKey = null,
170+
string? model = null,
147171
GenerationConfig? generationConfig = null,
148172
List<SafetySetting>? safetySettings = null) : this()
149173
{
150-
_apiKey = apiKey;
151-
_model = model.Sanitize();
152-
_generationConfig = generationConfig;
153-
_safetySettings = safetySettings;
174+
_apiKey = apiKey ?? _apiKey;
175+
_model = model.SanitizeModelName() ?? _model;
176+
_generationConfig ??= generationConfig;
177+
_safetySettings ??= safetySettings;
154178

155179
if (!string.IsNullOrEmpty(apiKey))
156180
{
@@ -179,7 +203,7 @@ internal GenerativeModel(string projectId, string region,
179203
_useVertexAi = true;
180204
_projectId = projectId;
181205
_region = region;
182-
_model = model.Sanitize();
206+
_model = model.SanitizeModelName();
183207
_generationConfig = generationConfig;
184208
_safetySettings = safetySettings;
185209

@@ -521,15 +545,6 @@ public async Task<CountTokensResponse> CountTokens(List<IPart>? parts)
521545
return await CountTokens(request);
522546
}
523547

524-
/// <summary>
525-
/// Returns the name of the model.
526-
/// </summary>
527-
/// <returns>Name of the model.</returns>
528-
public string Name()
529-
{
530-
return _model;
531-
}
532-
533548
// Todo: Implementation missing
534549
/// <summary>
535550
/// Starts a chat session.
@@ -627,5 +642,79 @@ internal JsonSerializerOptions DefaultJsonSerializerOptions()
627642

628643
return options;
629644
}
645+
646+
private string GetAccessTokenFromAdc()
647+
{
648+
if (System.Runtime.InteropServices.RuntimeInformation.IsOSPlatform(System.Runtime.InteropServices.OSPlatform.Windows))
649+
{
650+
return RunExternalExe("cmd.exe", "/c gcloud auth application-default print-access-token").TrimEnd();
651+
}
652+
else
653+
{
654+
return RunExternalExe("gcloud", "auth application-default print-access-token").TrimEnd();
655+
}
656+
}
657+
658+
private string RunExternalExe(string filename, string arguments = null)
659+
{
660+
var process = new Process();
661+
662+
process.StartInfo.FileName = filename;
663+
if (!string.IsNullOrEmpty(arguments))
664+
{
665+
process.StartInfo.Arguments = arguments;
666+
}
667+
668+
process.StartInfo.CreateNoWindow = true;
669+
process.StartInfo.WindowStyle = ProcessWindowStyle.Hidden;
670+
process.StartInfo.UseShellExecute = false;
671+
672+
process.StartInfo.RedirectStandardError = true;
673+
process.StartInfo.RedirectStandardOutput = true;
674+
var stdOutput = new StringBuilder();
675+
process.OutputDataReceived += (sender, args) => stdOutput.AppendLine(args.Data); // Use AppendLine rather than Append since args.Data is one line of output, not including the newline character.
676+
677+
string stdError = null;
678+
try
679+
{
680+
process.Start();
681+
process.BeginOutputReadLine();
682+
stdError = process.StandardError.ReadToEnd();
683+
process.WaitForExit();
684+
}
685+
catch (Exception e)
686+
{
687+
throw new Exception("OS error while executing " + Format(filename, arguments)+ ": " + e.Message, e);
688+
}
689+
690+
if (process.ExitCode == 0)
691+
{
692+
return stdOutput.ToString();
693+
}
694+
else
695+
{
696+
var message = new StringBuilder();
697+
698+
if (!string.IsNullOrEmpty(stdError))
699+
{
700+
message.AppendLine(stdError);
701+
}
702+
703+
if (stdOutput.Length != 0)
704+
{
705+
message.AppendLine("Std output:");
706+
message.AppendLine(stdOutput.ToString());
707+
}
708+
709+
throw new Exception(Format(filename, arguments) + " finished with exit code = " + process.ExitCode + ": " + message);
710+
}
711+
}
712+
713+
private string Format(string filename, string arguments)
714+
{
715+
return "'" + filename +
716+
((string.IsNullOrEmpty(arguments)) ? string.Empty : " " + arguments) +
717+
"'";
718+
}
630719
}
631720
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#if NET472_OR_GREATER || NETSTANDARD2_0
2+
using System;
3+
using System.Linq;
4+
using System.Text.Json;
5+
#endif
6+
7+
namespace Mscc.GenerativeAI
8+
{
9+
public static class GenerativeModelExtensions
10+
{
11+
public static string? SanitizeModelName(this string? value)
12+
{
13+
if (value == null) return value;
14+
15+
if (value.StartsWith("model", StringComparison.InvariantCultureIgnoreCase))
16+
{
17+
var parts = value.Split(new char[] { '/' }, StringSplitOptions.RemoveEmptyEntries);
18+
value = parts.Last();
19+
}
20+
return value.ToLower();
21+
}
22+
23+
public static string? GetValue(this JsonElement element, string key)
24+
{
25+
if (key == null) throw new ArgumentNullException(nameof(key));
26+
27+
string result = null;
28+
if (element.TryGetProperty(key, out JsonElement value))
29+
{
30+
result = value.GetString();
31+
}
32+
33+
return result;
34+
}
35+
}
36+
}

src/Mscc.GenerativeAI/Mscc.GenerativeAI.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<PackageLicenseFile>LICENSE</PackageLicenseFile>
2323
<PackageRequireLicenseAcceptance>False</PackageRequireLicenseAcceptance>
2424
<PackageReleaseNotes>
25+
- use EnvVars for default values (parameterless constructor)
2526
- improve Function Calling
2627
- improve Chat streaming
2728
- improve Embeddings
@@ -30,7 +31,7 @@
3031
</PropertyGroup>
3132

3233
<PropertyGroup Condition="$(TargetFramework.StartsWith('net6.0')) or $(TargetFramework.StartsWith('net7.0')) or $(TargetFramework.StartsWith('net8.0'))">
33-
<IsTrimmable>true</IsTrimmable>
34+
<IsTrimmable>false</IsTrimmable>
3435
<ImplicitUsings>enable</ImplicitUsings>
3536
<Nullable>enable</Nullable>
3637
</PropertyGroup>

tests/Mscc.GenerativeAI.Google/GenerativeModelGoogle_Should.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void Initiate_Default_Model()
101101

102102
// Assert
103103
model.Should().NotBeNull();
104-
model.Name().Should().Be(Model.Gemini10Pro);
104+
model.Name.Should().Be(Model.Gemini10Pro);
105105
}
106106

107107
[Theory]
@@ -121,7 +121,7 @@ public void Initiate_Model(string expected)
121121

122122
// Assert
123123
model.Should().NotBeNull();
124-
model.Name().Should().Be(expected);
124+
model.Name.Should().Be(expected);
125125
}
126126

127127
[Fact]

tests/Mscc.GenerativeAI.Google/Test.Mscc.GenerativeAI.Google.csproj

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,12 @@
3636
<None Update="payload\scones.jpg">
3737
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
3838
</None>
39-
<None Remove="client_secret.json" />
4039
<Content Include="client_secret.json">
4140
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
4241
</Content>
43-
<None Remove="key.p12" />
4442
<Content Include="key.p12">
4543
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
4644
</Content>
47-
<None Remove="tokens.json" />
4845
<Content Include="tokens.json">
4946
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
5047
</Content>

tests/Mscc.GenerativeAI/GoogleAi_Embedding_Should.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public void Initialize_Model()
2929

3030
// Assert
3131
model.Should().NotBeNull();
32-
model.Name().Should().Be(expected);
32+
model.Name.Should().Be(expected);
3333
}
3434

3535
[Fact]

tests/Mscc.GenerativeAI/GoogleAi_GeminiProVision_Should.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public void Initialize_GeminiProVision()
3434

3535
// Assert
3636
model.Should().NotBeNull();
37-
model.Name().Should().Be(Model.GeminiProVision);
37+
model.Name.Should().Be(Model.GeminiProVision);
3838
}
3939

4040
[Fact]

tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#if NET472_OR_GREATER || NETSTANDARD2_0
2+
using System;
23
using System.Collections.Generic;
34
#endif
45
using FluentAssertions;
@@ -13,14 +14,28 @@ public class GoogleAi_GeminiPro_Should
1314
{
1415
private readonly ITestOutputHelper output;
1516
private readonly ConfigurationFixture fixture;
16-
private readonly string model = Model.GeminiPro;
17+
private readonly string model = Model.Gemini10Pro;
1718

1819
public GoogleAi_GeminiPro_Should(ITestOutputHelper output, ConfigurationFixture fixture)
1920
{
2021
this.output = output;
2122
this.fixture = fixture;
2223
}
2324

25+
[Fact]
26+
public void Initialize_EnvVars()
27+
{
28+
// Arrange
29+
Environment.SetEnvironmentVariable("GOOGLE_API_KEY", fixture.ApiKey);
30+
31+
// Act
32+
var model = new GenerativeModel();
33+
34+
// Assert
35+
model.Should().NotBeNull();
36+
model.Name.Should().Be(Model.Gemini10Pro);
37+
}
38+
2439
[Fact]
2540
public void Initialize_Default_Model()
2641
{
@@ -31,7 +46,7 @@ public void Initialize_Default_Model()
3146

3247
// Assert
3348
model.Should().NotBeNull();
34-
model.Name().Should().Be(Model.GeminiPro);
49+
model.Name.Should().Be(Model.Gemini10Pro);
3550
}
3651

3752
[Fact]
@@ -44,7 +59,7 @@ public void Initialize_Model()
4459

4560
// Assert
4661
model.Should().NotBeNull();
47-
model.Name().Should().Be(Model.GeminiPro);
62+
model.Name.Should().Be(Model.Gemini10Pro);
4863
}
4964

5065
[Fact]

0 commit comments

Comments
 (0)