Skip to content

Доступ к id загруженных файлов + Исправление сохранения внутренних сообщений Spring AI в ChatResponseMetadata #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions docs/response-metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
Для получения дополнительной информации об ответе создан утилитный класс
[GigaChatResponseUtils](../spring-ai-gigachat/src/main/java/chat/giga/springai/support/GigaChatResponseUtils.java)

### Получение всей истории переписки с GigaChat
### Получение всей переписки с GigaChat под капотом Spring AI

Если Вам необходимо получить информацию обо всех сообщениях,
которые были получены и отправлены под капотом фреймфорка Spring AI
(например - какие тулы были вызваны, с какими параметрами, каковы результаты вызова тулов),
можно воспользоваться утилитным методом `GigaChatResponseUtils.getConversationHistory(chatResponse)`.
можно воспользоваться утилитным методом `GigaChatResponseUtils.getInternalMessages(chatResponse)`.

Пример:

Expand All @@ -18,10 +18,40 @@ ChatResponse chatResponse = chatClient
.toolCallbacks(GigaTools.from(new WeatherTools()))
.call()
.chatResponse();
List<Message> toolResponseMessages = GigaChatResponseUtils.getConversationHistory(chatResponse)
List<Message> toolResponseMessages = GigaChatResponseUtils.getInternalMessages(chatResponse)
.stream()
.filter(msg -> MessageType.TOOL.equals(msg.getMessageType()))
.toList();
log.info("Было вызвано {} функций", toolResponseMessages.size());
```

### Получение иденификаторов загруженных файлов при использовании Multimodality

Если Вам необходимо получить идентификаторы загруженных файлов,
(например - чтобы затем повторно использовать их в запросе или наоборот удалить),
можно воспользоваться утилитным методом `GigaChatResponseUtils.getUploadedMediaIds(chatResponse)`.

Пример:

```java
ChatResponse chatResponse = chatClient
.prompt()
.user(u -> u.text("Какая порода кота на фото?")
.media(new Media(MimeType.valueOf(multipartFile.getContentType()), multipartFile.getResource())))
.call()
.chatResponse();

String mediaId = GigaChatResponseUtils.getUploadedMediaIds(chatResponse).get(0);

// переиспользуем загруженные файлы в последующем запросе
chatClient.prompt()
.user(u -> u.text("Какого цвета шерсть у кота на фото?")
// при повторном использовании mimeType и data не важны, но не должны быть null
.media(Media.builder().id(mediaId).mimeType(MediaType.ALL).data("").build()))
.call()
.chatResponse();

// удаляем файлы
uploadedMediaIds.forEach(gigaChatApi::deleteFile);
```

Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;

import chat.giga.springai.GigaChatOptions;
import chat.giga.springai.api.chat.GigaChatApi;
import chat.giga.springai.support.GigaChatResponseUtils;
import java.util.List;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.content.Media;
import org.springframework.http.MediaType;
import org.springframework.util.MimeType;
Expand All @@ -16,8 +20,10 @@
public class MultimodalityController {

private final ChatClient chatClient;
private final GigaChatApi gigaChatApi;

public MultimodalityController(ChatClient.Builder chatClientBuilder) {
public MultimodalityController(ChatClient.Builder chatClientBuilder, GigaChatApi gigaChatApi) {
this.gigaChatApi = gigaChatApi;
this.chatClient = chatClientBuilder
.defaultAdvisors(new SimpleLoggerAdvisor())
.defaultOptions(
Expand All @@ -28,12 +34,17 @@ public MultimodalityController(ChatClient.Builder chatClientBuilder) {
@PostMapping(value = "/multimodality/chat", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public String chatWithMultimodality(
@RequestParam String userMessage, @RequestParam("file") MultipartFile multipartFile) {
return chatClient
ChatResponse chatResponse = chatClient
.prompt()
.user(u -> u.text(userMessage)
.media(new Media(
MimeType.valueOf(multipartFile.getContentType()), multipartFile.getResource())))
.call()
.content();
.chatResponse();

List<String> uploadedMediaIds = GigaChatResponseUtils.getUploadedMediaIds(chatResponse);
uploadedMediaIds.forEach(gigaChatApi::deleteFile);

return chatResponse.getResult().getOutput().getText();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public String weatherToolAnnotation(@RequestBody String question) {
.toolCallbacks(GigaTools.from(new WeatherTools()))
.call()
.chatResponse();
List<Message> toolResponseMessages = GigaChatResponseUtils.getConversationHistory(chatResponse).stream()
List<Message> toolResponseMessages = GigaChatResponseUtils.getInternalMessages(chatResponse).stream()
.filter(msg -> MessageType.TOOL.equals(msg.getMessageType()))
.toList();
log.info("Было вызвано {} функций", toolResponseMessages.size());
Expand Down
119 changes: 103 additions & 16 deletions spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.*;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -47,7 +48,8 @@ public class GigaChatModel implements ChatModel {
public static final String DEFAULT_MODEL_NAME = GigaChatApi.ChatModel.GIGA_CHAT_2.getName();
public static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION =
new DefaultChatModelObservationConvention();
public static final String CONVERSATION_HISTORY = "conversationHistory";
public static final String INTERNAL_CONVERSATION_HISTORY = "GigaChatInternalConversationHistory";
public static final String UPLOADED_MEDIA_IDS = "GigaChatUploadedMediaIds";
private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER =
ToolCallingManager.builder().build();

Expand Down Expand Up @@ -150,8 +152,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
Usage accumulatedUsage =
UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

ChatResponse chatResponse =
toChatResponse(completionResponse, accumulatedUsage, false, prompt.getInstructions());
ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, false);
observationContext.setResponse(chatResponse);

return chatResponse;
Expand All @@ -172,7 +173,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
}
}

return response;
return buildChatResponseWithCustomMetadata(prompt, response);
}

@Override
Expand Down Expand Up @@ -213,8 +214,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Usage accumulatedUsage =
UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);

ChatResponse chatResponse =
toChatResponse(completionResponse, accumulatedUsage, true, prompt.getInstructions());
ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, true);

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(
prompt.getOptions(), chatResponse)) {
Expand All @@ -233,7 +233,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
}
}

return Flux.just(chatResponse);
return Flux.just(buildChatResponseWithCustomMetadata(prompt, chatResponse));
})
.doOnError(observation::error)
.doFinally(s -> observation.stop())
Expand Down Expand Up @@ -286,7 +286,10 @@ Prompt buildRequestPrompt(Prompt prompt) {

ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());

return new Prompt(prompt.getInstructions(), requestOptions);
// Uploads media and sets an id to media
List<Message> messagesWithUploadedMediaIds = uploadMedia(prompt.getInstructions());

return new Prompt(messagesWithUploadedMediaIds, requestOptions);
}

private CompletionRequest createRequest(Prompt prompt, boolean stream) {
Expand All @@ -295,10 +298,7 @@ private CompletionRequest createRequest(Prompt prompt, boolean stream) {
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<UUID> filesIds = userMessage.getMedia().stream()
.map(media -> gigaChatApi
.uploadFile(media)
.getBody()
.id())
.map(media -> UUID.fromString(media.getId()))
.toList();

return List.of(new CompletionRequest.Message(
Expand Down Expand Up @@ -363,6 +363,55 @@ private CompletionRequest createRequest(Prompt prompt, boolean stream) {
return request;
}

/**
* Загружает медиа файлы, если они переданы в UserMessage, и проставляет к ним id.
* @param messages - исходные сообщения
* @return - обновленные сообщения с проставленными id для media
*/
private List<Message> uploadMedia(List<Message> messages) {
return messages.stream()
.map(message -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Смотрится страшненько, может вынесем в отдельный метод?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

исправил

if (message instanceof UserMessage userMessage) {
return buildUserMessageWithUploadedMedia(userMessage);
} else {
return message;
}
})
.toList();
}

private UserMessage buildUserMessageWithUploadedMedia(UserMessage userMessage) {
List<Media> mediaList = userMessage.getMedia();

// Если нет медиа, то ничего не меняем
if (CollectionUtils.isEmpty(mediaList)) {
return userMessage;
}
var mediaWithIds = mediaList.stream().map(this::uploadMediaAndSetId).toList();
return UserMessage.builder()
.text(userMessage.getText())
.metadata(userMessage.getMetadata())
.media(mediaWithIds)
.build();
}

// Загрузка файла в GigaChat, если у media нет id.
private Media uploadMediaAndSetId(Media media) {
// если id указан - значит файл уже загружен и можно возвращать media как есть
if (media.getId() != null) {
return media;
}
// иначе загружаем файл и получаем его id
String mediaId = gigaChatApi.uploadFile(media).getBody().id().toString();

return Media.builder()
.id(mediaId)
.name(media.getName())
.data(media.getData())
.mimeType(media.getMimeType())
.build();
}

private Object getFunctionCall(GigaChatOptions requestOptions, List<ToolDefinition> toolDefinitions) {
var callMode = requestOptions.getFunctionCallMode();

Expand Down Expand Up @@ -415,13 +464,11 @@ private List<CompletionRequest.FunctionDescription> getFunctionDescriptions(List
.toList();
}

private ChatResponse toChatResponse(
CompletionResponse completionResponse, Usage usage, boolean streaming, List<Message> conversationHistory) {
private ChatResponse toChatResponse(CompletionResponse completionResponse, Usage usage, boolean streaming) {
List<Generation> generations = completionResponse.getChoices().stream()
.map(choice -> buildGeneration(completionResponse.getId(), choice, streaming))
.toList();
return new ChatResponse(
generations, from(completionResponse, usage, Map.of(CONVERSATION_HISTORY, conversationHistory)));
return new ChatResponse(generations, from(completionResponse, usage));
}

private Generation buildGeneration(String id, CompletionResponse.Choice choice, boolean streaming) {
Expand Down Expand Up @@ -455,6 +502,46 @@ private Generation buildGeneration(String id, CompletionResponse.Choice choice,
return new Generation(assistantMessage, generationMetadata);
}

private ChatResponse buildChatResponseWithCustomMetadata(Prompt prompt, ChatResponse originalResponse) {
// т.к. этот метод вызывается при обратном проходе из рекурсии internalCall/internalStream,
// то нужно заполнять метаданные только один раз при первом вызове
if (originalResponse.getMetadata().containsKey(INTERNAL_CONVERSATION_HISTORY)) {
return originalResponse;
}

List<Message> messages = prompt.getInstructions();

// ищем индекс последнего пользовательского сообщения, т.к. здесь могут быть сообщения из ChatMemory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Также лучше вынести в отдельный метод логику поиска и последнего сообщения также с условием типизации под UserMessage, так как ниже есть оптимистичный каст под него, что не очень хорошо

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

логику могу вынести, ок - но результатам метода будет индекс, int.
мне в одном месте нужен индекс, а в другом уже объект UserMessage, от каста не избавлюсь

int lastUserMessageIndex = getIndexOfLastUserMessage(messages);

// Должен включать только AssistantMessage и ToolResponseMessage,
// т.е. внутренние сообщения от GigaChat с параметрами вызова функции, а результаты вызова функций
var internalConversationHistory = new ArrayList<>(messages.subList(lastUserMessageIndex + 1, messages.size()));
var chatResponseBuilder = ChatResponse.builder()
.from(originalResponse)
.metadata(INTERNAL_CONVERSATION_HISTORY, internalConversationHistory);

// ID загруженных медиа файлов
UserMessage lastUserMessage = (UserMessage) messages.get(lastUserMessageIndex);
if (!CollectionUtils.isEmpty(lastUserMessage.getMedia())) {
chatResponseBuilder.metadata(
UPLOADED_MEDIA_IDS,
lastUserMessage.getMedia().stream().map(Media::getId).toList());
}

return chatResponseBuilder.build();
}

// Возвращает индекс последнего пользовательского сообщения, или -1, если их нет
private int getIndexOfLastUserMessage(List<Message> messages) {
for (int i = messages.size() - 1; i >= 0; i--) {
if (messages.get(i) instanceof UserMessage) {
return i;
}
}
return -1;
}

private ChatResponseMetadata from(CompletionResponse completionResponse) {
return from(completionResponse, buildUsage(completionResponse.getUsage()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
@UtilityClass
@Slf4j
public class GigaChatResponseUtils {
public static List<Message> getConversationHistory(ChatResponse chatResponse) {
public static List<Message> getInternalMessages(ChatResponse chatResponse) {
return getFromMetadata(chatResponse, GigaChatModel.INTERNAL_CONVERSATION_HISTORY, List.of());
}

public static List<String> getUploadedMediaIds(ChatResponse chatResponse) {
return getFromMetadata(chatResponse, GigaChatModel.UPLOADED_MEDIA_IDS, List.of());
}

private static <T> T getFromMetadata(ChatResponse chatResponse, String key, T defaultValue) {
if (chatResponse != null && chatResponse.getMetadata() != null) {
List<Message> messages = chatResponse.getMetadata().get(GigaChatModel.CONVERSATION_HISTORY);
return messages == null ? List.of() : messages;
T data = chatResponse.getMetadata().get(key);
return data == null ? defaultValue : data;
}
return List.of();
return defaultValue;
}
}
Loading