diff --git a/docs/response-metadata.md b/docs/response-metadata.md new file mode 100644 index 0000000..bbe4cfd --- /dev/null +++ b/docs/response-metadata.md @@ -0,0 +1,27 @@ +## Дополнительная метаинформация в ответе ChatResponse + +Для получения дополнительной информации об ответе создан утилитный класс +[GigaChatResponseUtils](../spring-ai-gigachat/src/main/java/chat/giga/springai/support/GigaChatResponseUtils.java) + +### Получение всей истории переписки с GigaChat + +Если Вам необходимо получить информацию обо всех сообщениях, +которые были получены и отправлены под капотом фреймфорка Spring AI +(например - какие тулы были вызваны, с какими параметрами, каковы результаты вызова тулов), +можно воспользоваться утилитным методом `GigaChatResponseUtils.getConversationHistory(chatResponse)`. + +Пример: + +```java +ChatResponse chatResponse = chatClient + .prompt(question) + .toolCallbacks(GigaTools.from(new WeatherTools())) + .call() + .chatResponse(); +List toolResponseMessages = GigaChatResponseUtils.getConversationHistory(chatResponse) + .stream() + .filter(msg -> MessageType.TOOL.equals(msg.getMessageType())) + .toList(); +log.info("Было вызвано {} функций", toolResponseMessages.size()); +``` + diff --git a/docs/tools.md b/docs/tools.md index 43f420b..b4b4d71 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -1,4 +1,4 @@ -### Вызов пользовательских функций +## Вызов пользовательских функций Для вызова внешних функций пользуйтесь официальной документацией Spring AI - https://docs.spring.io/spring-ai/reference/api/tools.html. diff --git a/spring-ai-gigachat-example/README.md b/spring-ai-gigachat-example/README.md index 27d26c4..af49340 100644 --- a/spring-ai-gigachat-example/README.md +++ b/spring-ai-gigachat-example/README.md @@ -75,7 +75,7 @@ curl localhost:8080/tool/v1/weather -d "Какая температура в К curl localhost:8080/tool/v2/weather -d "Сколько градусов в Спб?" -H "content-type:application/json" curl localhost:8080/tool/v3/weather -d "Сколько градусов в Москве?" -H "content-type:application/json" curl localhost:8080/tool/v4/weather -d "Сколько градусов в Сочи будет завтра?" -H "content-type:application/json" -curl localhost:8080/tool/v4/weather -d "Какое давление в Сочи будет завтра?" -H "content-type:application/json" +curl localhost:8080/tool/v4/weather -d "Сколько градусов и какое давление в Сочи будет завтра?" -H "content-type:application/json" ``` ## Примеры использования RAG diff --git a/spring-ai-gigachat-example/src/main/java/chat/giga/springai/example/WeatherToolController.java b/spring-ai-gigachat-example/src/main/java/chat/giga/springai/example/WeatherToolController.java index 21974c7..90faa0c 100644 --- a/spring-ai-gigachat-example/src/main/java/chat/giga/springai/example/WeatherToolController.java +++ b/spring-ai-gigachat-example/src/main/java/chat/giga/springai/example/WeatherToolController.java @@ -2,15 +2,20 @@ import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; +import chat.giga.springai.support.GigaChatResponseUtils; import chat.giga.springai.tool.GigaTools; import chat.giga.springai.tool.annotation.FewShotExample; import chat.giga.springai.tool.annotation.GigaTool; import chat.giga.springai.tool.function.GigaFunctionToolCallback; import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import java.util.List; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.annotation.ToolParam; import org.springframework.ai.tool.function.FunctionToolCallback; @@ -150,12 +155,18 @@ String getTemperature( */ @PostMapping("tool/v4/weather") public String weatherToolAnnotation(@RequestBody String question) { - return chatClient + ChatResponse chatResponse = chatClient .prompt(question) // Важно использовать .toolCallbacks(GigaTools.from()), чтобы обрабатывались аннотации @GigaTool и @Tool // Если использовать конструкцию .tools(new WeatherTools()), то будет использоваться только @Tool .toolCallbacks(GigaTools.from(new WeatherTools())) .call() - .content(); + .chatResponse(); + List toolResponseMessages = GigaChatResponseUtils.getConversationHistory(chatResponse).stream() + .filter(msg -> MessageType.TOOL.equals(msg.getMessageType())) + .toList(); + log.info("Было вызвано {} функций", toolResponseMessages.size()); + toolResponseMessages.forEach(msg -> log.info(msg.toString())); + return chatResponse.getResult().getOutput().getText(); } } diff --git a/spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java b/spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java index ebf580c..74771f0 100644 --- a/spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java +++ b/spring-ai-gigachat/src/main/java/chat/giga/springai/GigaChatModel.java @@ -16,10 +16,7 @@ import java.util.stream.Collectors; import lombok.Setter; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.messages.*; import org.springframework.ai.chat.metadata.*; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; @@ -50,6 +47,7 @@ public class GigaChatModel implements ChatModel { public static final String DEFAULT_MODEL_NAME = GigaChatApi.ChatModel.GIGA_CHAT_2.getName(); public static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); + public static final String CONVERSATION_HISTORY = "conversationHistory"; private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); @@ -140,13 +138,20 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ctx -> this.gigaChatApi.chatCompletionEntity(request, buildHeaders(prompt.getOptions()))); CompletionResponse completionResponse = completionEntity.getBody(); + + if (completionResponse == null) { + log.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + completionResponse.setId(completionEntity.getHeaders().getFirst(X_REQUEST_ID)); Usage currentChatResponseUsage = buildUsage(completionResponse.getUsage()); Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); - ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, false); + ChatResponse chatResponse = + toChatResponse(completionResponse, accumulatedUsage, false, prompt.getInstructions()); observationContext.setResponse(chatResponse); return chatResponse; @@ -200,11 +205,16 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ctx -> this.gigaChatApi.chatCompletionStream(request, buildHeaders(prompt.getOptions()))); Flux chatResponseFlux = response.switchMap(completionResponse -> { + if (completionResponse == null) { + log.warn("No chat completion returned for prompt: {}", prompt); + return Flux.just(new ChatResponse(List.of())); + } Usage currentChatResponseUsage = buildUsage(completionResponse.getUsage()); Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); - ChatResponse chatResponse = toChatResponse(completionResponse, accumulatedUsage, true); + ChatResponse chatResponse = + toChatResponse(completionResponse, accumulatedUsage, true, prompt.getInstructions()); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired( prompt.getOptions(), chatResponse)) { @@ -405,16 +415,13 @@ private List getFunctionDescriptions(List .toList(); } - private ChatResponse toChatResponse(CompletionResponse completionResponse, Usage usage, boolean streaming) { - if (completionResponse == null) { - log.warn("Null completion response"); - return new ChatResponse(List.of()); - } - + private ChatResponse toChatResponse( + CompletionResponse completionResponse, Usage usage, boolean streaming, List conversationHistory) { List generations = completionResponse.getChoices().stream() .map(choice -> buildGeneration(completionResponse.getId(), choice, streaming)) .toList(); - return new ChatResponse(generations, from(completionResponse, usage)); + return new ChatResponse( + generations, from(completionResponse, usage, Map.of(CONVERSATION_HISTORY, conversationHistory))); } private Generation buildGeneration(String id, CompletionResponse.Choice choice, boolean streaming) { @@ -453,6 +460,11 @@ private ChatResponseMetadata from(CompletionResponse completionResponse) { } private ChatResponseMetadata from(CompletionResponse completionResponse, Usage usage) { + return from(completionResponse, usage, Map.of()); + } + + private ChatResponseMetadata from( + CompletionResponse completionResponse, Usage usage, Map metadata) { Assert.notNull(completionResponse, "GigaChat CompletionResponse must not be null"); return ChatResponseMetadata.builder() .id(completionResponse.getId()) @@ -460,6 +472,7 @@ private ChatResponseMetadata from(CompletionResponse completionResponse, Usage u .usage(usage) .keyValue("created", completionResponse.getCreated()) .keyValue("object", completionResponse.getObject()) + .metadata(metadata) .build(); } diff --git a/spring-ai-gigachat/src/main/java/chat/giga/springai/support/GigaChatResponseUtils.java b/spring-ai-gigachat/src/main/java/chat/giga/springai/support/GigaChatResponseUtils.java new file mode 100644 index 0000000..5a80640 --- /dev/null +++ b/spring-ai-gigachat/src/main/java/chat/giga/springai/support/GigaChatResponseUtils.java @@ -0,0 +1,25 @@ +package chat.giga.springai.support; + +import chat.giga.springai.GigaChatModel; +import java.util.List; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; + +/** + * Utils for providing type-safe access to ChatResponse metadata properties from GigaChat model response. + * + * @author Linar Abzaltdinov + */ +@UtilityClass +@Slf4j +public class GigaChatResponseUtils { + public static List getConversationHistory(ChatResponse chatResponse) { + if (chatResponse != null && chatResponse.getMetadata() != null) { + List messages = chatResponse.getMetadata().get(GigaChatModel.CONVERSATION_HISTORY); + return messages == null ? List.of() : messages; + } + return List.of(); + } +} diff --git a/spring-ai-gigachat/src/test/java/chat/giga/springai/support/GigaChatResponseUtilsTest.java b/spring-ai-gigachat/src/test/java/chat/giga/springai/support/GigaChatResponseUtilsTest.java new file mode 100644 index 0000000..3c04ff1 --- /dev/null +++ b/spring-ai-gigachat/src/test/java/chat/giga/springai/support/GigaChatResponseUtilsTest.java @@ -0,0 +1,55 @@ +package chat.giga.springai.support; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +import chat.giga.springai.GigaChatModel; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.model.ChatResponse; + +@ExtendWith(MockitoExtension.class) +public class GigaChatResponseUtilsTest { + + @Mock + private ChatResponse chatResponse; + + @Test + public void testGetConversationHistoryWithNullChatResponse() { + List result = GigaChatResponseUtils.getConversationHistory(null); + assertEquals(Collections.emptyList(), result); + } + + @Test + public void testGetConversationHistoryWithNullMetadata() { + when(chatResponse.getMetadata()).thenReturn(null); + List result = GigaChatResponseUtils.getConversationHistory(chatResponse); + assertEquals(Collections.emptyList(), result); + } + + @Test + public void testGetConversationHistoryWithEmptyMetadata() { + when(chatResponse.getMetadata()) + .thenReturn(ChatResponseMetadata.builder().build()); + List result = GigaChatResponseUtils.getConversationHistory(chatResponse); + assertEquals(Collections.emptyList(), result); + } + + @Test + public void testGetConversationHistoryWithConversationHistoryInMetadata() { + Message message = new UserMessage("Hello, world!"); + var metadata = ChatResponseMetadata.builder() + .keyValue(GigaChatModel.CONVERSATION_HISTORY, List.of(message)) + .build(); + when(chatResponse.getMetadata()).thenReturn(metadata); + List result = GigaChatResponseUtils.getConversationHistory(chatResponse); + assertEquals(Collections.singletonList(message), result); + } +}