Skip to content

Commit d71f90f

Browse files
committed
add OpenAiApi tests
1 parent b49458b commit d71f90f

File tree

3 files changed

+160
-1
lines changed

3 files changed

+160
-1
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,17 @@ public class OpenAiApi {
6666
* @param openAiToken OpenAI apiKey.
6767
*/
6868
public OpenAiApi(String openAiToken) {
69-
this(DEFAULT_BASE_URL, openAiToken, RestClient.builder());
69+
this(DEFAULT_BASE_URL, openAiToken);
70+
}
71+
72+
/**
73+
* Create a new chat completion api.
74+
*
75+
* @param baseUrl api base URL.
76+
* @param openAiToken OpenAI apiKey.
77+
*/
78+
public OpenAiApi(String baseUrl, String openAiToken) {
79+
this(baseUrl, openAiToken, RestClient.builder());
7080
}
7181

7282
/**
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.api;
18+
19+
import java.util.List;
20+
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
23+
import reactor.core.publisher.Flux;
24+
25+
import org.springframework.ai.openai.api.OpenAiApi;
26+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
27+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
28+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
29+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
30+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
31+
import org.springframework.ai.openai.api.OpenAiApi.Embedding;
32+
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
33+
import org.springframework.http.ResponseEntity;
34+
35+
import static org.assertj.core.api.Assertions.assertThat;
36+
37+
/**
38+
* @author Christian Tzolov
39+
*/
40+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
41+
public class OpenAiApiIT {
42+
43+
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
44+
45+
@Test
46+
void chatCompletionEntity() {
47+
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
48+
ResponseEntity<ChatCompletion> response = openAiApi.chatCompletionEntity(
49+
new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8f, false));
50+
51+
assertThat(response).isNotNull();
52+
assertThat(response.getBody()).isNotNull();
53+
}
54+
55+
@Test
56+
void chatCompletionStream() {
57+
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
58+
Flux<ChatCompletionChunk> response = openAiApi.chatCompletionStream(
59+
new ChatCompletionRequest(List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8f, true));
60+
61+
assertThat(response).isNotNull();
62+
assertThat(response.collectList().block()).isNotNull();
63+
}
64+
65+
@Test
66+
void embeddings() {
67+
ResponseEntity<EmbeddingList<Embedding>> response = openAiApi
68+
.embeddings(new OpenAiApi.EmbeddingRequest<String>("Hello world"));
69+
70+
assertThat(response).isNotNull();
71+
assertThat(response.getBody().data()).hasSize(1);
72+
assertThat(response.getBody().data().get(0).embedding()).hasSize(1536);
73+
}
74+
75+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.api;
18+
19+
import org.junit.jupiter.api.Disabled;
20+
import org.junit.jupiter.api.Test;
21+
22+
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
23+
import org.springframework.http.client.SimpleClientHttpRequestFactory;
24+
import org.springframework.web.client.RestClient;
25+
import org.springframework.web.client.RestClient.Builder;
26+
import org.springframework.web.client.RestTemplate;
27+
import org.springframework.web.util.DefaultUriBuilderFactory;
28+
29+
import static org.assertj.core.api.Assertions.assertThat;
30+
31+
/**
32+
* @author Christian Tzolov
33+
*/
34+
public class RestClientBuilderTests {
35+
36+
public static final String BASE_URL = "https://dog.ceo";
37+
38+
@Test
39+
public void test1() {
40+
test(RestClient.builder(), BASE_URL);
41+
}
42+
43+
@Test
44+
@Disabled("RestClient.builder(restTemplate) bug: https://github.com/spring-projects/spring-framework/issues/32180")
45+
public void test2() {
46+
RestTemplate restTemplate = new RestTemplate();
47+
test(RestClient.builder(restTemplate), BASE_URL);
48+
}
49+
50+
@Test
51+
public void test3() {
52+
RestTemplate restTemplate = new RestTemplate();
53+
restTemplate.setUriTemplateHandler(new DefaultUriBuilderFactory(BASE_URL));
54+
test(RestClient.builder(restTemplate), BASE_URL);
55+
}
56+
57+
@Test
58+
public void test4() {
59+
var clientHttpRequestFactory = new SimpleClientHttpRequestFactory();
60+
clientHttpRequestFactory.setConnectTimeout(5000);
61+
// clientHttpRequestFactory.setProxy(new Proxy(Type.HTTP,
62+
// InetSocketAddress.createUnresolved("localhost", 80)));
63+
RestClient.Builder builder = RestClient.builder().requestFactory(clientHttpRequestFactory);
64+
test(builder, BASE_URL);
65+
}
66+
67+
private void test(Builder restClientBuilder, String baseUrl) {
68+
var restClient = restClientBuilder.baseUrl(baseUrl).build();
69+
String res = restClient.get().uri("/api/breeds/list/all").retrieve().body(String.class);
70+
71+
assertThat(res).isNotNull();
72+
}
73+
74+
}

0 commit comments

Comments
 (0)