Skip to content

Commit 191c87f

Browse files
committed
feat: consume last finish event which contains finish_reason
1 parent 0caf067 commit 191c87f

File tree

2 files changed

+69
-32
lines changed

2 files changed

+69
-32
lines changed

llm-bedrock/src/conversions.rs

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use aws_sdk_bedrockruntime::{
88
types::{
99
ContentBlockDeltaEvent, ContentBlockStartEvent, ConversationRole,
1010
ConverseStreamMetadataEvent, ConverseStreamOutput, ImageBlock, ImageFormat,
11-
InferenceConfiguration, SystemContentBlock, Tool, ToolConfiguration, ToolInputSchema,
12-
ToolSpecification, ToolUseBlock,
11+
InferenceConfiguration, MessageStopEvent, SystemContentBlock, Tool, ToolConfiguration,
12+
ToolInputSchema, ToolSpecification, ToolUseBlock,
1313
},
1414
};
1515
use golem_llm::golem::llm::llm;
@@ -433,67 +433,77 @@ fn bedrock_image_to_llm_content_part(block: bedrock::types::ImageBlock) -> llm::
433433

434434
pub fn converse_stream_output_to_stream_event(
435435
event: ConverseStreamOutput,
436-
) -> Option<Vec<llm::StreamEvent>> {
436+
) -> Option<llm::StreamEvent> {
437437
match event {
438438
ConverseStreamOutput::ContentBlockStart(block) => process_content_block_start_event(block),
439439
ConverseStreamOutput::ContentBlockDelta(block) => process_content_block_delta_event(block),
440440
ConverseStreamOutput::Metadata(metadata) => process_metadata_event(metadata),
441+
ConverseStreamOutput::MessageStop(event) => process_message_stop_event(event),
441442
_ => None,
442443
}
443444
}
444445

445-
fn process_content_block_start_event(
446-
block: ContentBlockStartEvent,
447-
) -> Option<Vec<llm::StreamEvent>> {
446+
fn process_content_block_start_event(block: ContentBlockStartEvent) -> Option<llm::StreamEvent> {
448447
if let Some(start_info) = block.start {
449448
if start_info.is_tool_use() {
450449
let tool_use = start_info.as_tool_use().unwrap().clone();
451-
return Some(vec![llm::StreamEvent::Delta(llm::StreamDelta {
450+
return Some(llm::StreamEvent::Delta(llm::StreamDelta {
452451
content: None,
453452
tool_calls: Some(vec![llm::ToolCall {
454453
id: tool_use.tool_use_id,
455454
name: tool_use.name,
456455
arguments_json: "".to_owned(),
457456
}]),
458-
})]);
457+
}));
459458
}
460459
}
461460
None
462461
}
463462

464-
fn process_content_block_delta_event(
465-
block: ContentBlockDeltaEvent,
466-
) -> Option<Vec<llm::StreamEvent>> {
463+
fn process_content_block_delta_event(block: ContentBlockDeltaEvent) -> Option<llm::StreamEvent> {
467464
if let Some(block_info) = block.delta {
468465
if block_info.is_tool_use() {
469466
let tool_use = block_info.as_tool_use().unwrap().clone();
470-
return Some(vec![llm::StreamEvent::Delta(llm::StreamDelta {
467+
return Some(llm::StreamEvent::Delta(llm::StreamDelta {
471468
content: None,
472469
tool_calls: Some(vec![llm::ToolCall {
473470
id: "".to_owned(),
474471
name: "".to_owned(),
475472
arguments_json: tool_use.input,
476473
}]),
477-
})]);
474+
}));
478475
} else if block_info.is_text() {
479476
let text = block_info.as_text().unwrap().clone();
480-
return Some(vec![llm::StreamEvent::Delta(llm::StreamDelta {
477+
return Some(llm::StreamEvent::Delta(llm::StreamDelta {
481478
content: Some(vec![llm::ContentPart::Text(text)]),
482479
tool_calls: None,
483-
})]);
480+
}));
484481
}
485482
}
486483
None
487484
}
488485

489-
fn process_metadata_event(metadata: ConverseStreamMetadataEvent) -> Option<Vec<llm::StreamEvent>> {
490-
Some(vec![llm::StreamEvent::Finish(llm::ResponseMetadata {
486+
fn process_metadata_event(metadata: ConverseStreamMetadataEvent) -> Option<llm::StreamEvent> {
487+
Some(llm::StreamEvent::Finish(llm::ResponseMetadata {
491488
finish_reason: None,
492489
timestamp: None,
493490
usage: metadata.usage().map(bedrock_usage_to_llm_usage),
494491
provider_id: Some("bedrock".to_owned()),
495492
provider_metadata_json: None,
496-
})])
493+
}))
494+
}
495+
496+
fn process_message_stop_event(event: MessageStopEvent) -> Option<llm::StreamEvent> {
497+
Some(llm::StreamEvent::Finish(llm::ResponseMetadata {
498+
finish_reason: Some(bedrock_stop_reason_to_finish_reason(event.stop_reason())),
499+
timestamp: None,
500+
usage: None,
501+
provider_id: None,
502+
provider_metadata_json: event
503+
.additional_model_response_fields
504+
.clone()
505+
.and_then(smithy_document_to_metadata_json),
506+
}))
497507
}
498508

499509
fn json_str_to_smithy_document(value: &str) -> Result<Document, llm::Error> {
@@ -595,3 +605,18 @@ pub fn custom_error(code: llm::ErrorCode, message: String) -> llm::Error {
595605
provider_error_json: None,
596606
}
597607
}
608+
609+
pub fn merge_metadata(
610+
mut metadata1: llm::ResponseMetadata,
611+
metadata2: llm::ResponseMetadata,
612+
) -> llm::ResponseMetadata {
613+
metadata1.usage = metadata1.usage.or(metadata2.usage);
614+
metadata1.timestamp = metadata1.timestamp.or(metadata2.timestamp);
615+
metadata1.provider_id = metadata1.provider_id.or(metadata2.provider_id);
616+
metadata1.finish_reason = metadata1.finish_reason.or(metadata2.finish_reason);
617+
metadata1.provider_metadata_json = metadata1
618+
.provider_metadata_json
619+
.or(metadata2.provider_metadata_json);
620+
621+
metadata1
622+
}

llm-bedrock/src/stream.rs

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::cell::{RefCell, RefMut};
77

88
use crate::{
99
client::get_async_runtime,
10-
conversions::{converse_stream_output_to_stream_event, custom_error},
10+
conversions::{converse_stream_output_to_stream_event, custom_error, merge_metadata},
1111
};
1212

1313
type BedrockEventSource =
@@ -51,18 +51,11 @@ impl BedrockChatStream {
5151
fn set_finished(&self) {
5252
*self.finished.borrow_mut() = true;
5353
}
54-
}
55-
56-
impl llm::GuestChatStream for BedrockChatStream {
57-
fn get_next(&self) -> Option<Vec<llm::StreamEvent>> {
58-
if self.is_finished() {
59-
return Some(vec![]);
60-
}
61-
54+
fn get_single_event(&self) -> Option<llm::StreamEvent> {
6255
if let Some(stream) = self.stream_mut().as_mut() {
6356
let runtime = get_async_runtime();
6457

65-
runtime.block_on(async {
58+
runtime.block_on(async move {
6659
let token = stream.recv().await;
6760

6861
match token {
@@ -72,24 +65,43 @@ impl llm::GuestChatStream for BedrockChatStream {
7265
}
7366
Ok(None) => {
7467
self.set_finished();
75-
Some(vec![])
68+
None
7669
}
7770
Err(error) => {
7871
self.set_finished();
79-
Some(vec![llm::StreamEvent::Error(custom_error(
72+
Some(llm::StreamEvent::Error(custom_error(
8073
llm::ErrorCode::InternalError,
8174
format!("An error occurred while reading event stream: {error}"),
82-
))])
75+
)))
8376
}
8477
}
8578
})
8679
} else if let Some(error) = self.failure() {
8780
self.set_finished();
88-
Some(vec![llm::StreamEvent::Error(error.clone())])
81+
Some(llm::StreamEvent::Error(error.clone()))
8982
} else {
9083
None
9184
}
9285
}
86+
}
87+
88+
impl llm::GuestChatStream for BedrockChatStream {
89+
fn get_next(&self) -> Option<Vec<llm::StreamEvent>> {
90+
if self.is_finished() {
91+
return Some(vec![]);
92+
}
93+
self.get_single_event().map(|event| {
94+
if let llm::StreamEvent::Finish(metadata) = event.clone() {
95+
if let Some(llm::StreamEvent::Finish(final_metadata)) = self.get_single_event() {
96+
return vec![llm::StreamEvent::Finish(merge_metadata(
97+
metadata,
98+
final_metadata,
99+
))];
100+
}
101+
}
102+
vec![event]
103+
})
104+
}
93105

94106
fn blocking_get_next(&self) -> Vec<llm::StreamEvent> {
95107
let mut result = Vec::new();

0 commit comments

Comments
 (0)