// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace DevProxy.Abstractions.LanguageModel;

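// Client for a language model hosted in LM Studio. LM Studio exposes an
// OpenAI-compatible REST API, so availability is probed via the /v1/models
// endpoint and completions are requested from /v1/completions and
// /v1/chat/completions.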
public class LMStudioLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
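    // Memoized result of the availability check; null until IsEnabledAsync has run.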
    private bool? _lmAvailable;
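    // Response caches used when CacheResponses is enabled, keyed by prompt text
    // and by chat message sequence respectively.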
    private readonly Dictionary<string, OpenAICompletionResponse> _cacheCompletion = [];
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> _cacheChatCompletion = [];

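    // Reports whether the language model is reachable and correctly configured,
    // memoizing the result of the first check.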
    public async Task<bool> IsEnabledAsync()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternalAsync();
        return _lmAvailable.Value;
    }

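    // Validates the configuration, probes the /v1/models endpoint, and issues a
    // test completion to confirm that the configured model responds.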
    private async Task<bool> IsEnabledInternalAsync()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // check if lm is on
            using var client = new HttpClient();
            var response = await client.GetAsync($"{_configuration.Url}/v1/models");
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternalAsync("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}. Param: {param}", testCompletion.Error.Message, testCompletion.Error.Param);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

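    // Generates a completion for the given prompt. Returns null when the model is
    // unavailable or the request fails; returns a cached response when
    // CacheResponses is enabled and the prompt has been seen before.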
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternalAsync(prompt, options);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheCompletion[prompt] = response;
            }

            return response;
        }
    }

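    // Posts the completion request to the /v1/completions endpoint. The
    // temperature defaults to 0.8 when no options are provided.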
    private async Task<OpenAICompletionResponse?> GenerateCompletionInternalAsync(string prompt, CompletionOptions? options = null)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/completions";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false,
                    temperature = options?.Temperature ?? 0.8,
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAICompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

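    // Generates a chat completion for the given message sequence, with the same
    // guard clauses and caching behavior as GenerateCompletionAsync.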
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternalAsync(messages);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }
        else
        {
            if (_configuration.CacheResponses && response.Response is not null)
            {
                _cacheChatCompletion[messages] = response;
            }

            return response;
        }
    }

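    // Posts the chat completion request to the /v1/chat/completions endpoint.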
    private async Task<OpenAIChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/chat/completions";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAIChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }
}

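// Cache lookup helpers that match chat completion cache keys by comparing
// message sequences (SequenceEqual) rather than array references.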
internal static class CacheChatCompletionExtensions
{
    // Finds an existing cache key whose messages match the given sequence. The
    // dictionary is keyed by ILanguageModelChatCompletionMessage[] to match the
    // _cacheChatCompletion field above.
    public static ILanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    public static bool TryGetValue(
        this Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OpenAIChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is null)
        {
            value = null;
            return false;
        }

        value = cache[key];
        return true;
    }
}
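
// Example usage (a minimal sketch; the configuration values below are
// assumptions for illustration, not part of this file — LM Studio listens on
// http://localhost:1234 by default):
//
//   var config = new LanguageModelConfiguration
//   {
//       Enabled = true,
//       Url = "http://localhost:1234",
//       Model = "llama-3.2-3b-instruct",
//       CacheResponses = true
//   };
//   ILanguageModelClient client = new LMStudioLanguageModelClient(config, logger);
//   if (await client.IsEnabledAsync())
//   {
//       var completion = await client.GenerateCompletionAsync("Say hello.");
//       Console.WriteLine(completion?.Response);
//   }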