-
Notifications
You must be signed in to change notification settings - Fork 6
Доступ к id загруженных файлов + Исправление сохранения внутренних сообщений Spring AI в ChatResponseMetadata #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; | ||
import org.springframework.ai.chat.prompt.ChatOptions; | ||
import org.springframework.ai.chat.prompt.Prompt; | ||
import org.springframework.ai.content.Media; | ||
import org.springframework.ai.model.ModelOptionsUtils; | ||
import org.springframework.ai.model.tool.*; | ||
import org.springframework.ai.retry.RetryUtils; | ||
|
@@ -47,7 +48,8 @@ public class GigaChatModel implements ChatModel { | |
public static final String DEFAULT_MODEL_NAME = GigaChatApi.ChatModel.GIGA_CHAT_2.getName(); | ||
public static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = | ||
new DefaultChatModelObservationConvention(); | ||
public static final String CONVERSATION_HISTORY = "conversationHistory"; | ||
public static final String INTERNAL_CONVERSATION_HISTORY = "GigaChatInternalConversationHistory"; | ||
public static final String UPLOADED_MEDIA_IDS = "GigaChatUploadedMediaIds"; | ||
private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = | ||
ToolCallingManager.builder().build(); | ||
|
||
|
@@ -150,8 +152,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons | |
Usage accumulatedUsage = | ||
UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); | ||
|
||
ChatResponse chatResponse = | ||
toChatResponse(completionResponse, accumulatedUsage, false, prompt.getInstructions()); | ||
ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, false); | ||
observationContext.setResponse(chatResponse); | ||
|
||
return chatResponse; | ||
|
@@ -172,7 +173,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons | |
} | ||
} | ||
|
||
return response; | ||
return buildChatResponseWithCustomMetadata(prompt, response); | ||
} | ||
|
||
@Override | ||
|
@@ -213,8 +214,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha | |
Usage accumulatedUsage = | ||
UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); | ||
|
||
ChatResponse chatResponse = | ||
toChatResponse(completionResponse, accumulatedUsage, true, prompt.getInstructions()); | ||
ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, true); | ||
|
||
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired( | ||
prompt.getOptions(), chatResponse)) { | ||
|
@@ -233,7 +233,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha | |
} | ||
} | ||
|
||
return Flux.just(chatResponse); | ||
return Flux.just(buildChatResponseWithCustomMetadata(prompt, chatResponse)); | ||
}) | ||
.doOnError(observation::error) | ||
.doFinally(s -> observation.stop()) | ||
|
@@ -286,7 +286,10 @@ Prompt buildRequestPrompt(Prompt prompt) { | |
|
||
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); | ||
|
||
return new Prompt(prompt.getInstructions(), requestOptions); | ||
// Uploads media and sets an id to media | ||
List<Message> messagesWithUploadedMediaIds = uploadMedia(prompt.getInstructions()); | ||
|
||
return new Prompt(messagesWithUploadedMediaIds, requestOptions); | ||
} | ||
|
||
private CompletionRequest createRequest(Prompt prompt, boolean stream) { | ||
|
@@ -295,10 +298,7 @@ private CompletionRequest createRequest(Prompt prompt, boolean stream) { | |
if (message instanceof UserMessage userMessage) { | ||
if (!CollectionUtils.isEmpty(userMessage.getMedia())) { | ||
List<UUID> filesIds = userMessage.getMedia().stream() | ||
.map(media -> gigaChatApi | ||
.uploadFile(media) | ||
.getBody() | ||
.id()) | ||
.map(media -> UUID.fromString(media.getId())) | ||
.toList(); | ||
|
||
return List.of(new CompletionRequest.Message( | ||
|
@@ -363,6 +363,55 @@ private CompletionRequest createRequest(Prompt prompt, boolean stream) { | |
return request; | ||
} | ||
|
||
/** | ||
* Загружает медиа файлы, если они переданы в UserMessage, и проставляет к ним id. | ||
* @param messages - исходные сообщения | ||
* @return - обновленные сообщения с проставленными id для media | ||
*/ | ||
private List<Message> uploadMedia(List<Message> messages) { | ||
return messages.stream() | ||
.map(message -> { | ||
if (message instanceof UserMessage userMessage) { | ||
return buildUserMessageWithUploadedMedia(userMessage); | ||
} else { | ||
return message; | ||
} | ||
}) | ||
.toList(); | ||
} | ||
|
||
private UserMessage buildUserMessageWithUploadedMedia(UserMessage userMessage) { | ||
List<Media> mediaList = userMessage.getMedia(); | ||
|
||
// Если нет медиа, то ничего не меняем | ||
if (CollectionUtils.isEmpty(mediaList)) { | ||
return userMessage; | ||
} | ||
var mediaWithIds = mediaList.stream().map(this::uploadMediaAndSetId).toList(); | ||
return UserMessage.builder() | ||
.text(userMessage.getText()) | ||
.metadata(userMessage.getMetadata()) | ||
.media(mediaWithIds) | ||
.build(); | ||
} | ||
|
||
// Загрузка файла в GigaChat, если у media нет id. | ||
private Media uploadMediaAndSetId(Media media) { | ||
// если id указан - значит файл уже загружен и можно возвращать media как есть | ||
if (media.getId() != null) { | ||
return media; | ||
} | ||
// иначе загружаем файл и получаем его id | ||
String mediaId = gigaChatApi.uploadFile(media).getBody().id().toString(); | ||
|
||
return Media.builder() | ||
.id(mediaId) | ||
.name(media.getName()) | ||
.data(media.getData()) | ||
.mimeType(media.getMimeType()) | ||
.build(); | ||
} | ||
|
||
private Object getFunctionCall(GigaChatOptions requestOptions, List<ToolDefinition> toolDefinitions) { | ||
var callMode = requestOptions.getFunctionCallMode(); | ||
|
||
|
@@ -415,13 +464,11 @@ private List<CompletionRequest.FunctionDescription> getFunctionDescriptions(List | |
.toList(); | ||
} | ||
|
||
private ChatResponse toChatResponse( | ||
CompletionResponse completionResponse, Usage usage, boolean streaming, List<Message> conversationHistory) { | ||
private ChatResponse toChatResponse(CompletionResponse completionResponse, Usage usage, boolean streaming) { | ||
List<Generation> generations = completionResponse.getChoices().stream() | ||
.map(choice -> buildGeneration(completionResponse.getId(), choice, streaming)) | ||
.toList(); | ||
return new ChatResponse( | ||
generations, from(completionResponse, usage, Map.of(CONVERSATION_HISTORY, conversationHistory))); | ||
return new ChatResponse(generations, from(completionResponse, usage)); | ||
} | ||
|
||
private Generation buildGeneration(String id, CompletionResponse.Choice choice, boolean streaming) { | ||
|
@@ -455,6 +502,46 @@ private Generation buildGeneration(String id, CompletionResponse.Choice choice, | |
return new Generation(assistantMessage, generationMetadata); | ||
} | ||
|
||
private ChatResponse buildChatResponseWithCustomMetadata(Prompt prompt, ChatResponse originalResponse) { | ||
// т.к. этот метод вызывается при обратном проходе из рекурсии internalCall/internalStream, | ||
// то нужно заполнять метаданные только один раз при первом вызове | ||
if (originalResponse.getMetadata().containsKey(INTERNAL_CONVERSATION_HISTORY)) { | ||
return originalResponse; | ||
} | ||
|
||
List<Message> messages = prompt.getInstructions(); | ||
|
||
// ищем индекс последнего пользовательского сообщения, т.к. здесь могут быть сообщения из ChatMemory | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Также лучше вынести в отдельный метод логику поиска и последнего сообщения также с условием типизации под There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. логику могу вынести, ок - но результатам метода будет индекс, int. |
||
int lastUserMessageIndex = getIndexOfLastUserMessage(messages); | ||
|
||
// Должен включать только AssistantMessage и ToolResponseMessage, | ||
// т.е. внутренние сообщения от GigaChat с параметрами вызова функции, а результаты вызова функций | ||
var internalConversationHistory = new ArrayList<>(messages.subList(lastUserMessageIndex + 1, messages.size())); | ||
var chatResponseBuilder = ChatResponse.builder() | ||
.from(originalResponse) | ||
.metadata(INTERNAL_CONVERSATION_HISTORY, internalConversationHistory); | ||
|
||
// ID загруженных медиа файлов | ||
UserMessage lastUserMessage = (UserMessage) messages.get(lastUserMessageIndex); | ||
if (!CollectionUtils.isEmpty(lastUserMessage.getMedia())) { | ||
chatResponseBuilder.metadata( | ||
UPLOADED_MEDIA_IDS, | ||
lastUserMessage.getMedia().stream().map(Media::getId).toList()); | ||
} | ||
|
||
return chatResponseBuilder.build(); | ||
} | ||
|
||
// Возвращает индекс последнего пользовательского сообщения, или -1, если их нет | ||
private int getIndexOfLastUserMessage(List<Message> messages) { | ||
for (int i = messages.size() - 1; i >= 0; i--) { | ||
if (messages.get(i) instanceof UserMessage) { | ||
return i; | ||
} | ||
} | ||
return -1; | ||
} | ||
|
||
private ChatResponseMetadata from(CompletionResponse completionResponse) { | ||
return from(completionResponse, buildUsage(completionResponse.getUsage())); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Смотрится страшненько, может вынесем в отдельный метод?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
исправил