This repository was archived by the owner on Jun 6, 2024. It is now read-only.
Function call + stream: arguments lost from the returned ChatFunctionCall #505
Open
Description
Here is my sample code:
public static void main(String[] args) throws UnknownHostException, InterruptedException {
    try {
        ObjectMapper mapper = defaultObjectMapper();
        OkHttpClient client = defaultClient("", Duration.of(10000L, ChronoUnit.SECONDS))
                .newBuilder()
                .build();
        Retrofit retrofit = defaultRetrofit(client, mapper);
        Class<Retrofit> clazz = Retrofit.class;
        Field baseUrl = clazz.getDeclaredField("baseUrl");
        baseUrl.setAccessible(true);
        baseUrl.set(retrofit, HttpUrl.get(BASE_URL));
        OpenAiApi api = retrofit.create(OpenAiApi.class);
        OpenAiService service = new OpenAiService(api);
        List<ChatMessage> messages = Lists.newArrayList();
        messages.add(new ChatMessage("system", "Please use the functions provided below to determine what function needs to be called for the user's problem. " +
                "If the necessary parameters are missing when calling the function, please return to the user in this format and prompt the user to pass the necessary parameters:\n" +
                "We also need the following information to complete your request: Required Parameter 1, Required Parameter 2\n" +
                "Make sure your prompts are accurate, polite, and the directly relevant information is obvious and understandable to users"));
        Scanner scanner = new Scanner(System.in);
        // "Tell me the weather"
        messages.add(new ChatMessage("user", scanner.nextLine()));
        while (true) {
            ChatFunctionDynamic chatFunctionDynamic = getChatFunctionDynamic();
            ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
                    .builder()
                    .model("qwen15-110b.credit-llm")
                    .messages(messages)
                    .n(1)
                    .maxTokens(256)
                    .functions(Lists.newArrayList(chatFunctionDynamic))
                    .functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
                    .build();
            Flowable<ChatCompletionChunk> flowable = service.streamChatCompletion(chatCompletionRequest);
            AtomicBoolean isFirst = new AtomicBoolean(true);
            ChatMessage responseMessage = service.mapStreamToAccumulator(flowable)
                    .doOnNext(accumulator -> {
                        if (accumulator.isFunctionCall()) {
                            ChatFunctionCall functionCall = accumulator.getAccumulatedChatFunctionCall();
                            if (isFirst.getAndSet(false)) {
                                System.out.println("Executing function " + functionCall.getName() + "...");
                            }
                        } else {
                            if (isFirst.getAndSet(false)) {
                                System.out.print("Response: ");
                            }
                            if (accumulator.getMessageChunk().getContent() != null) {
                                System.out.print(accumulator.getMessageChunk().getContent());
                            }
                        }
                    })
                    .doOnComplete(System.out::println)
                    .lastElement()
                    .blockingGet()
                    .getAccumulatedMessage();
            messages.add(responseMessage);
            ChatFunctionCall functionCall = responseMessage.getFunctionCall();
            if (functionCall != null) {
                if (functionCall.getName().equals("get_weather")) {
                    String location = functionCall.getArguments().get("location").asText();
                    String unit = functionCall.getArguments().get("unit").asText();
                    WeatherResponse weather = getWeather(location, unit);
                    ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), JSON.toJSONString(weather), "get_weather");
                    messages.add(weatherMessage);
                    continue;
                }
            }
            System.out.print("Next Query: ");
            String nextLine = scanner.nextLine();
            if (nextLine.equalsIgnoreCase("exit")) {
                System.exit(0);
            }
            messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
private static WeatherResponse getWeather(String location, String unit) {
    return new WeatherResponse(location, WeatherUnit.valueOf(unit), new Random().nextInt(40), "sunny");
}
public static ChatFunctionDynamic getChatFunctionDynamic() {
    return ChatFunctionDynamic.builder()
            .name("get_weather")
            .description("Get the current weather of a location")
            .addProperty(ChatFunctionProperty.builder()
                    .name("location")
                    .type("string")
                    .description("City and state, for example: León, Guanajuato")
                    // main() reads "location" unconditionally, so it must be required as well
                    .required(true)
                    .build())
            .addProperty(ChatFunctionProperty.builder()
                    .name("unit")
                    .type("string")
                    .description("The temperature unit, can be 'CELSIUS' or 'FAHRENHEIT'")
                    .enumValues(new HashSet<>(Arrays.asList("CELSIUS", "FAHRENHEIT")))
                    .required(true)
                    .build())
            .build();
}
The corresponding error is thrown on the line String location = functionCall.getArguments().get("location").asText();
java.lang.NullPointerException
at com.mybank.bkinfocenter.common.recognition.web.Test.main(Test.java:96)
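Independently of the streaming bug, the failing line reads both location and unit without checking that the model actually supplied them. A guard like the following JDK-only sketch turns a missing key into the reply format that the system prompt already asks for, rather than a NullPointerException. It uses a Map<String, String> as a hypothetical stand-in for the parsed arguments JsonNode; missingParameters and missingParametersMessage are illustrative names, not library API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: Map<String, String> stands in for the parsed arguments JsonNode.
class RequiredArgumentsSketch {

    /** Returns the required parameter names that are absent or blank in the parsed arguments. */
    static List<String> missingParameters(Map<String, String> arguments, List<String> required) {
        List<String> missing = new ArrayList<>();
        for (String name : required) {
            // A null map (no arguments at all) counts as every parameter missing.
            String value = arguments == null ? null : arguments.get(name);
            if (value == null || value.trim().isEmpty()) {
                missing.add(name);
            }
        }
        return missing;
    }

    /** Builds a reply in the format the sample's system prompt requests. */
    static String missingParametersMessage(List<String> missing) {
        return "We also need the following information to complete your request: "
                + String.join(", ", missing);
    }

    public static void main(String[] args) {
        Map<String, String> parsed = new LinkedHashMap<>();
        parsed.put("unit", "CELSIUS"); // the model omitted "location"
        List<String> missing = missingParameters(parsed, Arrays.asList("location", "unit"));
        if (!missing.isEmpty()) {
            System.out.println(missingParametersMessage(missing));
        }
    }
}
```

With Jackson, the equivalent presence check on the real arguments node would be arguments != null && arguments.hasNonNull(name).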
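Stepping through the code in a debugger shows that in com.theokanning.openai.service.OpenAiService#mapStreamToAccumulator, the arguments field of messageChunk is an ObjectNode, so calling asText() on it returns "". As a result the accumulated arguments end up empty, and the get("location") lookup in my code fails with the NullPointerException above.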
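To make the failure mode concrete, here is a JDK-only sketch of the accumulation step. It assumes, as the debugging above suggests, that a fragment arriving as an ObjectNode (for example when the server sends the whole arguments object in one chunk) is read with asText() and therefore contributes nothing. ArgFragment, TextFragment and ObjectFragment are hypothetical stand-ins for Jackson's JsonNode, TextNode and ObjectNode so the code runs without dependencies; the workaround shown is to append the raw text for string fragments and the serialized JSON for everything else.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Sketch of a streaming function-call argument accumulator.
 * TextFragment mimics a TextNode (asText() returns the raw text);
 * ObjectFragment mimics an ObjectNode (asText() returns "", toString() returns the JSON).
 */
class ArgumentsAccumulatorSketch {

    interface ArgFragment {
        boolean isTextual();
        String asText();
    }

    static final class TextFragment implements ArgFragment {
        private final String text;
        TextFragment(String text) { this.text = text; }
        @Override public boolean isTextual() { return true; }
        @Override public String asText() { return text; }
    }

    static final class ObjectFragment implements ArgFragment {
        private final String json;
        ObjectFragment(String json) { this.json = json; }
        @Override public boolean isTextual() { return false; }
        @Override public String asText() { return ""; }      // container nodes have no scalar text
        @Override public String toString() { return json; } // serialized JSON
    }

    /** Reproduces the reported behaviour: only asText() is appended, so object fragments vanish. */
    static String accumulateNaively(List<? extends ArgFragment> fragments) {
        StringBuilder sb = new StringBuilder();
        for (ArgFragment f : fragments) {
            sb.append(f.asText());
        }
        return sb.toString();
    }

    /** Workaround: raw text for string fragments, serialized JSON for anything else. */
    static String accumulateRobustly(List<? extends ArgFragment> fragments) {
        StringBuilder sb = new StringBuilder();
        for (ArgFragment f : fragments) {
            sb.append(f.isTextual() ? f.asText() : f.toString());
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<ArgFragment> oneShot = new ArrayList<>();
        oneShot.add(new ObjectFragment("{\"location\":\"Hangzhou\",\"unit\":\"CELSIUS\"}"));
        System.out.println("naive:  '" + accumulateNaively(oneShot) + "'");
        System.out.println("robust: '" + accumulateRobustly(oneShot) + "'");
    }
}
```

Applied to the real stream without touching the library, the same rule would mean iterating the raw Flowable<ChatCompletionChunk> returned by streamChatCompletion yourself (for instance with RxJava's blockingForEach) instead of calling mapStreamToAccumulator, appending node.isTextual() ? node.asText() : node.toString() for each non-null arguments JsonNode, and parsing the concatenated string once with ObjectMapper.readTree after the stream completes. The chunk accessors involved (getChoices().get(0).getMessage().getFunctionCall().getArguments()) are assumed here and should be checked against the library version in use.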
Is there a simple way to work around this without modifying the library source? Many thanks!