Skip to content

Commit df4d734

Browse files
ability to rewind chat history; improve handling of chat history (streaming)
1 parent 35763c2 commit df4d734

File tree

3 files changed

+156
-14
lines changed

3 files changed

+156
-14
lines changed

src/Mscc.GenerativeAI/CHANGELOG.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
### Changed
12+
### Fixed
13+
14+
## 0.8.2
15+
1016
### Added
1117

18+
- ability to rewind chat history
1219
- provide types to simplify creation of tuned model
1320
- compatibility methods for PaLM models
21+
- access text of content response easier
1422

1523
### Changed
16-
### Fixed
24+
25+
- improve handling of chat history (streaming)
1726

1827
## 0.8.1
1928

src/Mscc.GenerativeAI/Types/ChatSession.cs

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,40 @@
66
using System.Threading.Tasks;
77
#endif
88
using System.Runtime.CompilerServices;
9+
using System.Text;
910

1011
namespace Mscc.GenerativeAI
1112
{
13+
/// <summary>
14+
/// This ChatSession object collects the messages sent and received, in its ChatSession.History attribute.
15+
/// </summary>
1216
public class ChatSession
1317
{
1418
private readonly GenerativeModel _model;
1519
private readonly GenerationConfig? _generationConfig;
1620
private readonly List<SafetySetting>? _safetySettings;
1721
private readonly List<Tool>? _tools;
22+
private ContentResponse? _lastSent;
23+
private ContentResponse? _lastReceived;
1824

25+
/// <summary>
26+
/// The chat history.
27+
/// </summary>
1928
public List<ContentResponse> History { get; set; }
2029

2130
/// <summary>
22-
///
31+
/// Returns the last received ContentResponse
32+
/// </summary>
33+
public ContentResponse? Last => _lastReceived;
34+
35+
/// <summary>
36+
/// Constructor to start a chat session with history.
2337
/// </summary>
2438
/// <param name="model"></param>
2539
/// <param name="history"></param>
40+
/// <param name="generationConfig">Optional. Configuration options for model generation and outputs.</param>
41+
/// <param name="safetySettings">Optional. A list of unique SafetySetting instances for blocking unsafe content.</param>
42+
/// <param name="tools">Optional. </param>
2643
public ChatSession(GenerativeModel model,
2744
List<ContentResponse>? history = null,
2845
GenerationConfig? generationConfig = null,
@@ -37,38 +54,55 @@ public ChatSession(GenerativeModel model,
3754
}
3855

3956
/// <summary>
40-
///
57+
/// Sends the conversation history with the added message and returns the model's response.
58+
/// Appends the request and response to the conversation history.
4159
/// </summary>
4260
/// <param name="prompt"></param>
61+
/// <param name="generationConfig">Optional. Overrides for the model's generation config.</param>
62+
/// <param name="safetySettings">Optional. Overrides for the model's safety settings.</param>
4363
/// <returns></returns>
44-
public async Task<GenerateContentResponse> SendMessage(string prompt)
64+
public async Task<GenerateContentResponse> SendMessage(string prompt,
65+
GenerationConfig? generationConfig = null,
66+
List<SafetySetting>? safetySettings = null)
4567
{
4668
if (prompt == null) throw new ArgumentNullException(nameof(prompt));
4769
if (string.IsNullOrEmpty(prompt)) throw new ArgumentException(prompt, nameof(prompt));
4870

49-
History.Add(new ContentResponse { Role = Role.User, Parts = new List<Part> { new Part { Text = prompt } } });
71+
_lastSent = new ContentResponse
72+
{
73+
Role = Role.User, Parts = new List<Part> { new Part { Text = prompt } }
74+
};
75+
History.Add(_lastSent);
5076
var request = new GenerateContentRequest
5177
{
5278
Contents = History.Select(x =>
5379
new Content { Role = x.Role, PartTypes = x.Parts }
5480
).ToList(),
55-
GenerationConfig = _generationConfig,
56-
SafetySettings = _safetySettings,
81+
GenerationConfig = generationConfig ?? _generationConfig,
82+
SafetySettings = safetySettings ?? _safetySettings,
5783
Tools = _tools
5884
};
5985

6086
var response = await _model.GenerateContent(request);
61-
History.Add(new ContentResponse { Role = Role.Model, Parts = new List<Part> { new Part { Text = response.Text } } });
87+
_lastReceived = new ContentResponse
88+
{
89+
Role = Role.Model, Parts = new List<Part> { new Part { Text = response.Text ?? string.Empty } }
90+
};
91+
History.Add(_lastReceived);
6292
return response;
6393
}
6494

6595
/// <summary>
6696
///
6797
/// </summary>
6898
/// <param name="content"></param>
99+
/// <param name="generationConfig">Optional. Overrides for the model's generation config.</param>
100+
/// <param name="safetySettings">Optional. Overrides for the model's safety settings.</param>
69101
/// <param name="cancellationToken"></param>
70102
/// <returns></returns>
71-
public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object content,
103+
public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object content,
104+
GenerationConfig? generationConfig = null,
105+
List<SafetySetting>? safetySettings = null,
72106
[EnumeratorCancellation] CancellationToken cancellationToken = default)
73107
{
74108
if (content == null) throw new ArgumentNullException(nameof(content));
@@ -87,25 +121,60 @@ public async IAsyncEnumerable<GenerateContentResponse> SendMessageStream(object
87121
parts = contentParts;
88122
}
89123

90-
History.Add(new ContentResponse { Role = role, Parts = parts });
124+
_lastSent = new ContentResponse
125+
{
126+
Role = role, Parts = parts
127+
};
128+
History.Add(_lastSent);
91129
var request = new GenerateContentRequest
92130
{
93131
Contents = History.Select(x =>
94132
new Content { Role = x.Role, PartTypes = x.Parts }
95133
).ToList(),
96-
GenerationConfig = _generationConfig,
97-
SafetySettings = _safetySettings,
134+
GenerationConfig = generationConfig ?? _generationConfig,
135+
SafetySettings = safetySettings ?? _safetySettings,
98136
Tools = _tools
99137
};
100138

139+
var fullText = new StringBuilder();
101140
var response = _model.GenerateContentStream(request, cancellationToken);
102141
await foreach (var item in response)
103142
{
104143
if (cancellationToken.IsCancellationRequested)
105144
yield break;
106-
History.Add(new ContentResponse { Role = Role.Model, Parts = item?.Candidates?[0]?.Content?.Parts });
145+
fullText.Append(item.Text);
107146
yield return item;
108147
}
148+
_lastReceived = new ContentResponse
149+
{
150+
Role = Role.Model, Parts = new List<Part> { new Part { Text = fullText.ToString() } }
151+
};
152+
History.Add(_lastReceived);
153+
}
154+
155+
/// <summary>
/// Removes the last request/response pair from the chat history.
/// </summary>
/// <returns>Tuple with the last request/response pair.</returns>
/// <exception cref="InvalidOperationException">Thrown when the history does not contain a request/response pair.</exception>
public (ContentResponse? Sent, ContentResponse? Received) Rewind()
{
    // Guard: previously an empty or single-entry history made GetRange/RemoveRange
    // throw an opaque ArgumentException from a negative index. Fail with a clear message instead.
    if (History.Count < 2)
        throw new InvalidOperationException("Chat history does not contain a request/response pair to rewind.");

    (ContentResponse? Sent, ContentResponse? Received) result;
    var position = History.Count - 2;

    if (_lastReceived is null)
    {
        // No cached pair (e.g. session created with a pre-existing history and no
        // message sent yet) — read the last two entries from the history itself.
        var entries = History.GetRange(position, 2);
        result = (entries.FirstOrDefault(), entries.LastOrDefault());
    }
    else
    {
        result = (_lastSent, _lastReceived);
        // Clear the cache so a subsequent Rewind() falls back to the history.
        _lastSent = null;
        _lastReceived = null;
    }

    History.RemoveRange(position, 2);
    return result;
}
110179
}
111180
}

tests/Mscc.GenerativeAI/GoogleAi_GeminiPro_Should.cs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,66 @@ public async void Start_Chat_Conversations()
458458
chat.History.ForEach(c =>
459459
{
460460
output.WriteLine($"{new string('-', 20)}");
461-
output.WriteLine($"{c.Role}: {c.Parts[0].Text}");
461+
output.WriteLine($"{c.Role}: {c.Text}");
462462
});
463463
}
464464

465+
[Fact]
// Refs:
// https://ai.google.dev/tutorials/python_quickstart#chat_conversations
// Fix: async Task instead of async void — async void test methods let exceptions
// escape the test runner unobserved; xUnit awaits Task-returning tests properly.
public async Task Start_Chat_Rewind_Conversation()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var chat = model.StartChat();
    _ = await chat.SendMessage("Hello, fancy brainstorming about IT?");
    _ = await chat.SendMessage("In one sentence, explain how a computer works to a young child.");
    _ = await chat.SendMessage("Okay, how about a more detailed explanation to a high school kid?");
    _ = await chat.SendMessage("Lastly, give a thorough definition for a CS graduate.");

    // Act
    var entries = chat.Rewind();

    // Assert
    entries.Should().NotBeNull();
    entries.Sent.Should().NotBeNull();
    entries.Received.Should().NotBeNull();
    output.WriteLine("------ Rewind ------");
    output.WriteLine($"{entries.Sent.Role}: {entries.Sent.Text}");
    output.WriteLine($"{new string('-', 20)}");
    output.WriteLine($"{entries.Received.Role}: {entries.Received.Text}");
    output.WriteLine($"{new string('-', 20)}");

    // Four exchanges produce eight history entries; rewinding removes the last pair.
    chat.History.Count.Should().Be(6);
    output.WriteLine("------ History -----");
    chat.History.ForEach(c =>
    {
        output.WriteLine($"{new string('-', 20)}");
        output.WriteLine($"{c.Role}: {c.Text}");
    });
}
499+
500+
[Fact]
// Refs:
// https://ai.google.dev/tutorials/python_quickstart#chat_conversations
// Fix: async Task instead of async void — async void test methods let exceptions
// escape the test runner unobserved; xUnit awaits Task-returning tests properly.
public async Task Start_Chat_Conversations_Get_Last()
{
    // Arrange
    var model = new GenerativeModel(apiKey: fixture.ApiKey, model: this.model);
    var chat = model.StartChat();
    _ = await chat.SendMessage("Hello, fancy brainstorming about IT?");
    _ = await chat.SendMessage("In one sentence, explain how a computer works to a young child.");
    _ = await chat.SendMessage("Okay, how about a more detailed explanation to a high school kid?");
    _ = await chat.SendMessage("Lastly, give a thorough definition for a CS graduate.");

    // Act
    var sut = chat.Last;

    // Assert
    sut.Should().NotBeNull();
    output.WriteLine($"{sut.Role}: {sut.Text}");
}
520+
465521
[Fact]
466522
public async void Start_Chat_Streaming()
467523
{
@@ -486,6 +542,14 @@ public async void Start_Chat_Streaming()
486542
// output.WriteLine($"CandidatesTokenCount: {response?.UsageMetadata?.CandidatesTokenCount}");
487543
// output.WriteLine($"TotalTokenCount: {response?.UsageMetadata?.TotalTokenCount}");
488544
}
545+
chat.History.Count.Should().Be(2);
546+
output.WriteLine($"{new string('-', 20)}");
547+
output.WriteLine("------ History -----");
548+
chat.History.ForEach(c =>
549+
{
550+
output.WriteLine($"{new string('-', 20)}");
551+
output.WriteLine($"{c.Role}: {c.Text}");
552+
});
489553
}
490554

491555
[Fact]

0 commit comments

Comments
 (0)