
Commit a60dab1

waldekmastykarz authored and garrytrinder committed
Refactors local LM connection to make it generic. Closes #1105
1 parent e9571e1 commit a60dab1

File tree

9 files changed: +561 −246 lines changed


dev-proxy-abstractions/LanguageModel/ILanguageModelCompletionResponse.cs

Lines changed: 6 additions & 2 deletions
@@ -6,6 +6,10 @@ namespace DevProxy.Abstractions.LanguageModel;
 
 public interface ILanguageModelCompletionResponse
 {
-    string? Error { get; set; }
-    string? Response { get; set; }
+    string? ErrorMessage { get; }
+    string? Response { get; }
+    // custom property added to log in the mock output
+    string? RequestUrl { get; set; }
+
+    OpenAIResponse ConvertToOpenAIResponse();
 }
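
The interface change replaces the settable Error property with a read-only ErrorMessage, exposes Response as read-only, adds RequestUrl so the originating endpoint can be logged in the mock output, and requires implementations to normalize themselves to the OpenAI-compatible shape via ConvertToOpenAIResponse. A minimal sketch of a hypothetical consumer; LogCompletion and the logger instance are illustrative, only the interface and its members come from the commit:

// Hypothetical helper, not part of the commit.
static void LogCompletion(ILanguageModelCompletionResponse completion, ILogger logger)
{
    if (completion.ErrorMessage is not null)
    {
        logger.LogError("Language model error: {error}", completion.ErrorMessage);
        return;
    }

    // RequestUrl records which local endpoint produced the response.
    logger.LogInformation("Response from {url}: {response}", completion.RequestUrl, completion.Response);

    // Normalize to the OpenAI-compatible response shape used downstream.
    OpenAIResponse openAIResponse = completion.ConvertToOpenAIResponse();
}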
dev-proxy-abstractions/LanguageModel/LMStudioLanguageModelClient.cs

Lines changed: 269 additions & 0 deletions

@@ -0,0 +1,269 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace DevProxy.Abstractions.LanguageModel;

public class LMStudioLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    private bool? _lmAvailable;
    private readonly Dictionary<string, OpenAICompletionResponse> _cacheCompletion = [];
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> _cacheChatCompletion = [];

    public async Task<bool> IsEnabledAsync()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternalAsync();
        return _lmAvailable.Value;
    }

    private async Task<bool> IsEnabledInternalAsync()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // check if lm is on
            using var client = new HttpClient();
            var response = await client.GetAsync($"{_configuration.Url}/v1/models");
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternalAsync("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}. Param: {param}", testCompletion.Error.Message, testCompletion.Error.Param);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

    public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternalAsync(prompt, options);
        if (response == null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheCompletion[prompt] = response;
            }

            return response;
        }
    }

    private async Task<OpenAICompletionResponse?> GenerateCompletionInternalAsync(string prompt, CompletionOptions? options = null)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/completions";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false,
                    temperature = options?.Temperature ?? 0.8,
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAICompletionResponse>();
            if (res is null)
            {
                return res;
            }
            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternalAsync(messages);
        if (response == null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheChatCompletion[messages] = response;
            }

            return response;
        }
    }

    private async Task<OpenAIChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/chat/completions";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAIChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }
}

internal static class CacheChatCompletionExtensions
{
    public static OpenAIChatCompletionMessage[]? GetKey(
        this Dictionary<OpenAIChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    public static bool TryGetValue(
        this Dictionary<OpenAIChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OpenAIChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is null)
        {
            value = null;
            return false;
        }

        value = cache[key];
        return true;
    }
}
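
The new client mirrors the flow of the existing Ollama client: IsEnabledAsync probes /v1/models and runs a test completion once, caching the result, and the Generate* methods return null until that check has happened; successful responses are optionally cached. A minimal usage sketch, assuming a populated LanguageModelConfiguration (config) and an ILogger (logger); both names are illustrative:

// Sketch only: config and logger are assumed to exist in the hosting code.
var client = new LMStudioLanguageModelClient(config, logger);

// Availability must be checked first; GenerateCompletionAsync returns null otherwise.
if (await client.IsEnabledAsync())
{
    var completion = await client.GenerateCompletionAsync("Reply with a single word: ready?");
    if (completion?.ErrorMessage is null)
    {
        Console.WriteLine(completion?.Response);
    }
}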
dev-proxy-abstractions/LanguageModel/LanguageModelClientFactory.cs

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Extensions.Logging;

namespace DevProxy.Abstractions.LanguageModel;

public static class LanguageModelClientFactory
{
    public static ILanguageModelClient Create(LanguageModelConfiguration? config, ILogger logger)
    {
        return config?.Client switch
        {
            LanguageModelClient.LMStudio => new LMStudioLanguageModelClient(config, logger),
            LanguageModelClient.Ollama => new OllamaLanguageModelClient(config, logger),
            _ => new OllamaLanguageModelClient(config, logger)
        };
    }
}
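
The factory is what makes the local LM connection generic: callers ask for an ILanguageModelClient based on configuration instead of constructing a specific client, and any unrecognized or missing value falls back to Ollama. A short sketch of how a caller might use it; config, logger, and messages are assumed to exist and are illustrative:

// LMStudio and Ollama map to their respective clients; anything else,
// including a null configuration, falls back to OllamaLanguageModelClient.
ILanguageModelClient lmClient = LanguageModelClientFactory.Create(config, logger);

if (await lmClient.IsEnabledAsync())
{
    var chatResponse = await lmClient.GenerateChatCompletionAsync(messages);
    Console.WriteLine(chatResponse?.Response);
}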

dev-proxy-abstractions/LanguageModel/LanguageModelConfiguration.cs

Lines changed: 9 additions & 2 deletions
@@ -4,11 +4,18 @@
 
 namespace DevProxy.Abstractions.LanguageModel;
 
+public enum LanguageModelClient
+{
+    LMStudio,
+    Ollama
+}
+
 public class LanguageModelConfiguration
 {
+    public bool CacheResponses { get; set; } = true;
     public bool Enabled { get; set; } = false;
+    public LanguageModelClient Client { get; set; } = LanguageModelClient.Ollama;
     // default Ollama URL
-    public string? Url { get; set; } = "http://localhost:11434";
     public string? Model { get; set; } = "llama3.2";
-    public bool CacheResponses { get; set; } = true;
+    public string? Url { get; set; } = "http://localhost:11434";
 }
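
With the LanguageModelClient enum and the Client property in place, switching the backend is a single setting; Ollama stays the default, so existing configurations keep their current behavior. A small sketch of the two options, assuming the defaults shown above (the LM Studio URL and model name are illustrative):

// Default backend: Ollama at http://localhost:11434 with model llama3.2.
var ollamaConfig = new LanguageModelConfiguration { Enabled = true };

// Opt in to LM Studio; URL and model depend on the local setup.
var lmStudioConfig = new LanguageModelConfiguration
{
    Enabled = true,
    Client = LanguageModelClient.LMStudio,
    Url = "http://localhost:1234",
    Model = "llama-3.2-3b-instruct"
};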

dev-proxy-abstractions/LanguageModel/OllamaLanguageModelClient.cs

Lines changed: 1 addition & 1 deletion
@@ -244,7 +244,7 @@ private async Task<bool> IsEnabledInternalAsync()
     }
 }
 
-internal static class CacheChatCompletionExtensions
+internal static class OllamaCacheChatCompletionExtensions
 {
     public static OllamaLanguageModelChatCompletionMessage[]? GetKey(
         this Dictionary<OllamaLanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
