Skip to content

Commit b7cf095

Browse files
author
Denis Lobo
committed
add hugging face text to image integration
1 parent 175dc9b commit b7cf095

File tree

9 files changed

+818
-9
lines changed

9 files changed

+818
-9
lines changed

models/spring-ai-huggingface/pom.xml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
<version>3.0.46</version>
9494
<executions>
9595
<execution>
96+
<id>generate-chat-api</id>
9697
<goals>
9798
<goal>generate</goal>
9899
</goals>
@@ -101,7 +102,7 @@
101102
<language>java</language>
102103
<library>resttemplate</library>
103104
<apiPackage>org.springframework.ai.huggingface.api</apiPackage>
104-
<modelPackage>org.springframework.ai.huggingface.model</modelPackage>
105+
<modelPackage>org.springframework.ai.huggingface.model.chat</modelPackage>
105106
<invokerPackage>org.springframework.ai.huggingface.invoker</invokerPackage>
106107
<generateApiTests>false</generateApiTests>
107108
<generateModelTests>false</generateModelTests>
@@ -113,6 +114,30 @@
113114
</configOptions>
114115
</configuration>
115116
</execution>
117+
<execution>
118+
<id>generate-imagegen-api</id>
119+
<goals>
120+
<goal>generate</goal>
121+
</goals>
122+
<configuration>
123+
<inputSpec>${project.basedir}/src/main/resources/openapi-imagegen.json</inputSpec>
124+
<language>java</language>
125+
<library>resttemplate</library>
126+
<apiPackage>org.springframework.ai.huggingface.api</apiPackage>
127+
<modelPackage>org.springframework.ai.huggingface.model.imagegen</modelPackage>
128+
<invokerPackage>org.springframework.ai.huggingface.invoker</invokerPackage>
129+
<generateApiTests>false</generateApiTests>
130+
<generateModelTests>false</generateModelTests>
131+
<!-- use custom codegen-template to avoid accept-header selection in generated inference api -->
132+
<templateDirectory>src/main/resources/swagger-codegen/templates/Java</templateDirectory>
133+
<configOptions>
134+
<sourceFolder>src/main/java</sourceFolder>
135+
<dateLibrary>java8</dateLibrary>
136+
<!-- jackson secret sauce!! -->
137+
<notNullJacksonAnnotation>true</notNullJacksonAnnotation>
138+
</configOptions>
139+
</configuration>
140+
</execution>
116141
</executions>
117142
</plugin>
118143

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616

1717
package org.springframework.ai.huggingface;
1818

19-
import java.util.ArrayList;
20-
import java.util.List;
21-
import java.util.Map;
22-
2319
import com.fasterxml.jackson.core.type.TypeReference;
2420
import com.fasterxml.jackson.databind.ObjectMapper;
2521

@@ -32,10 +28,12 @@
3228
import org.springframework.ai.chat.prompt.Prompt;
3329
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
3430
import org.springframework.ai.huggingface.invoker.ApiClient;
35-
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
36-
import org.springframework.ai.huggingface.model.CompatGenerateRequest;
37-
import org.springframework.ai.huggingface.model.GenerateParameters;
38-
import org.springframework.ai.huggingface.model.GenerateResponse;
31+
32+
import org.springframework.ai.huggingface.model.chat.*;
33+
34+
import java.util.ArrayList;
35+
import java.util.List;
36+
import java.util.Map;
3937

4038
/**
4139
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
@@ -112,6 +110,10 @@ public ChatResponse call(Prompt prompt) {
112110
return new ChatResponse(generations);
113111
}
114112

113+
public Info info() {
114+
return this.textGenApi.getModelInfo();
115+
}
116+
115117
/**
116118
* Gets the maximum number of new tokens to be generated.
117119
* @return The maximum number of new tokens.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package org.springframework.ai.huggingface;
2+
3+
import org.springframework.ai.huggingface.api.ImageGenerationInferenceApi;
4+
import org.springframework.ai.huggingface.text2image.HuggingfaceImageOptions;
5+
import org.springframework.ai.huggingface.invoker.ApiClient;
6+
import org.springframework.ai.huggingface.model.imagegen.GenerateParameters;
7+
import org.springframework.ai.huggingface.model.imagegen.GenerateRequest;
8+
import org.springframework.ai.image.*;
9+
10+
import java.util.Base64;
11+
import java.util.List;
12+
13+
/**
14+
* An implementation of {@link ImageModel} that interfaces with HuggingFace Inference
15+
* Endpoints for text-to-image generation.
16+
*
17+
* @author Denis Lobo
18+
*/
19+
public class HuggingfaceImageModel implements ImageModel {
20+
21+
/**
22+
* Token required for authenticating with the HuggingFace Inference API.
23+
*/
24+
private final String apiToken;
25+
26+
/**
27+
* Client for making API calls.
28+
*/
29+
private ApiClient apiClient = new ApiClient();
30+
31+
private ImageGenerationInferenceApi imageGenApi = new ImageGenerationInferenceApi();
32+
33+
/**
34+
* Constructs a new HuggingfaceImageModel with the specified API token and base path.
35+
* @param apiToken The API token for HuggingFace.
36+
* @param basePath The base path for API requests.
37+
*/
38+
public HuggingfaceImageModel(final String apiToken, String basePath) {
39+
this.apiToken = apiToken;
40+
this.apiClient.setBasePath(basePath);
41+
this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken);
42+
this.imageGenApi.setApiClient(this.apiClient);
43+
}
44+
45+
@Override
46+
public ImageResponse call(ImagePrompt prompt) {
47+
final GenerateParameters generateParameters = createGenerateParameters(prompt.getOptions());
48+
final GenerateRequest generateRequest = createGenerateRequest(prompt.getInstructions(), generateParameters);
49+
50+
// huggingface eps with text-to-image models return only single image in default
51+
// mode
52+
final String base64Encoded = generateImage(generateRequest, prompt);
53+
final Image image = new Image(null, base64Encoded);
54+
final ImageGeneration imageGeneration = new ImageGeneration(image);
55+
return new ImageResponse(List.of(imageGeneration), new ImageResponseMetadata());
56+
}
57+
58+
private String generateImage(GenerateRequest generateRequest, ImagePrompt prompt) {
59+
final String mimeType = prompt.getOptions().getResponseFormat();
60+
switch (mimeType) {
61+
case "application/json" -> {
62+
return new String(this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat()));
63+
}
64+
default -> {
65+
byte[] bytes = this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat());
66+
return Base64.getEncoder().encodeToString(bytes);
67+
}
68+
}
69+
}
70+
71+
private GenerateRequest createGenerateRequest(List<ImageMessage> promptInstructs,
72+
GenerateParameters generateParameters) {
73+
final GenerateRequest request = new GenerateRequest();
74+
final List<String> instructions = promptInstructs.stream().map(ImageMessage::getText).toList();
75+
76+
request.setParameters(generateParameters);
77+
request.setInputs(instructions);
78+
return request;
79+
}
80+
81+
private GenerateParameters createGenerateParameters(ImageOptions options) {
82+
final GenerateParameters params = new GenerateParameters();
83+
params.setWidth(options.getWidth());
84+
params.setHeight(options.getHeight());
85+
params.setNumImagesPerPrompt(options.getN());
86+
87+
if (options instanceof HuggingfaceImageOptions hfImageOptions) {
88+
params.setClipSkip(hfImageOptions.getClipSkip());
89+
params.setGuidanceScale(hfImageOptions.getGuidanceScale());
90+
params.setNumInferenceSteps(hfImageOptions.getNumInferenceSteps());
91+
params.setNegativePrompt(List.of(hfImageOptions.getNegativePrompt()));
92+
}
93+
return params;
94+
}
95+
96+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package org.springframework.ai.huggingface.text2image;
2+
3+
import org.springframework.ai.image.ImageOptions;
4+
5+
public class HuggingfaceImageOptions implements ImageOptions {
6+
7+
private Integer numImagesPerPrompt;
8+
9+
private String model;
10+
11+
private Integer width;
12+
13+
private Integer height;
14+
15+
private String responseFormat;
16+
17+
private String negativePrompt;
18+
19+
private Float sigmaItems;
20+
21+
private Integer timestepItems;
22+
23+
private Integer clipSkip;
24+
25+
private Float guidanceScale;
26+
27+
private Integer numInferenceSteps;
28+
29+
@Override
30+
public Integer getN() {
31+
return numImagesPerPrompt;
32+
}
33+
34+
public void setN(Integer numImagesPerPrompt) {
35+
this.numImagesPerPrompt = numImagesPerPrompt;
36+
}
37+
38+
@Override
39+
public String getModel() {
40+
return model;
41+
}
42+
43+
public void setModel(String model) {
44+
this.model = model;
45+
}
46+
47+
@Override
48+
public Integer getWidth() {
49+
return width;
50+
}
51+
52+
public void setWidth(Integer width) {
53+
this.width = width;
54+
}
55+
56+
@Override
57+
public Integer getHeight() {
58+
return height;
59+
}
60+
61+
public void setHeight(Integer height) {
62+
this.height = height;
63+
}
64+
65+
@Override
66+
public String getResponseFormat() {
67+
return responseFormat;
68+
}
69+
70+
public void setResponseFormat(String responseFormat) {
71+
this.responseFormat = responseFormat;
72+
}
73+
74+
public String getNegativePrompt() {
75+
return negativePrompt;
76+
}
77+
78+
public void setNegativePrompt(String negativePrompt) {
79+
this.negativePrompt = negativePrompt;
80+
}
81+
82+
public Float getSigmaItems() {
83+
return sigmaItems;
84+
}
85+
86+
public void setSigmaItems(Float sigmaItems) {
87+
this.sigmaItems = sigmaItems;
88+
}
89+
90+
public Integer getTimestepItems() {
91+
return timestepItems;
92+
}
93+
94+
public void setTimestepItems(Integer timestepItems) {
95+
this.timestepItems = timestepItems;
96+
}
97+
98+
public Integer getClipSkip() {
99+
return clipSkip;
100+
}
101+
102+
public void setClipSkip(Integer clipSkip) {
103+
this.clipSkip = clipSkip;
104+
}
105+
106+
public Float getGuidanceScale() {
107+
return guidanceScale;
108+
}
109+
110+
public void setGuidanceScale(Float guidanceScale) {
111+
this.guidanceScale = guidanceScale;
112+
}
113+
114+
public Integer getNumInferenceSteps() {
115+
return numInferenceSteps;
116+
}
117+
118+
public void setNumInferenceSteps(Integer numInferenceSteps) {
119+
this.numInferenceSteps = numInferenceSteps;
120+
}
121+
122+
public static class Builder {
123+
124+
private final HuggingfaceImageOptions options = new HuggingfaceImageOptions();
125+
126+
public Builder builder() {
127+
return new Builder();
128+
}
129+
130+
public Builder withNumImagesPerPrompt(Integer numImagesPerPrompt) {
131+
options.setN(numImagesPerPrompt);
132+
return this;
133+
}
134+
135+
public Builder withModel(String model) {
136+
options.setModel(model);
137+
return this;
138+
}
139+
140+
public Builder withResponseFormat(String responseFormat) {
141+
options.setResponseFormat(responseFormat);
142+
return this;
143+
}
144+
145+
public Builder withWidth(Integer width) {
146+
options.setWidth(width);
147+
return this;
148+
}
149+
150+
public Builder withHeight(Integer height) {
151+
options.setHeight(height);
152+
return this;
153+
}
154+
155+
public Builder withNegativePrompt(String negativePrompt) {
156+
options.setNegativePrompt(negativePrompt);
157+
return this;
158+
}
159+
160+
public Builder withSigmaItems(Float sigmaItems) {
161+
options.setSigmaItems(sigmaItems);
162+
return this;
163+
}
164+
165+
public Builder withTimestepItems(Integer timestepItems) {
166+
options.setTimestepItems(timestepItems);
167+
return this;
168+
}
169+
170+
public Builder withClipSkip(Integer clipSkip) {
171+
options.setClipSkip(clipSkip);
172+
return this;
173+
}
174+
175+
public Builder withGuidanceScale(Float guidanceScale) {
176+
options.setGuidanceScale(guidanceScale);
177+
return this;
178+
}
179+
180+
public Builder withNumInferenceSteps(Integer numInferenceSteps) {
181+
options.setNumInferenceSteps(numInferenceSteps);
182+
return this;
183+
}
184+
185+
public HuggingfaceImageOptions build() {
186+
return options;
187+
}
188+
189+
}
190+
191+
}

0 commit comments

Comments
 (0)