messages, HuggingfaceImageOptions imageModelOptions) {
+ super(messages, imageModelOptions);
+ }
+
+ public HuggingfaceImagePrompt(ImageMessage imageMessage, HuggingfaceImageOptions imageOptions) {
+ this(Collections.singletonList(imageMessage), imageOptions);
+ }
+
+ public HuggingfaceImagePrompt(String instructions, HuggingfaceImageOptions imageOptions) {
+ this(new ImageMessage(instructions), imageOptions);
+ }
+
+ public HuggingfaceImagePrompt(String instructions) {
+ this(new ImageMessage(instructions), new HuggingfaceImageOptions());
+ }
+
+}
diff --git a/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json b/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json
new file mode 100644
index 00000000000..f3df0d9f8d5
--- /dev/null
+++ b/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json
@@ -0,0 +1,249 @@
+{
+ "openapi": "3.0.3",
+ "info": {
+ "title": "Image Generation Inference",
+ "description": "Image Generation Webserver",
+ "contact": {
+ "name": "Denis Lobo"
+ },
+ "license": {
+ "name": "Apache 2.0",
+ "url": "https://www.apache.org/licenses/LICENSE-2.0"
+ },
+ "version": "1.0.2"
+ },
+ "paths": {
+ "/": {
+ "post": {
+ "tags": [
+ "Image Generation Inference"
+ ],
+ "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+ "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+ "operationId": "generate",
+ "parameters": [
+ {
+ "name": "Accept",
+ "in": "header",
+ "required": false,
+ "schema": {
+ "type": "string",
+ "enum": [
+ "application/json",
+ "image/png",
+ "image/jpeg",
+ "image/webp",
+ "image/gif",
+ "image/tiff",
+ "image/bmp"
+ ]
+ }
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateRequest"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Generated image",
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/png": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/jpeg": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/bmp": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/gif": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/tiff": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ },
+ "image/webp": {
+ "schema": {
+ "type": "string",
+ "format": "byte"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Input validation error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Input validation error"
+ }
+ }
+ }
+ },
+ "424": {
+ "description": "Generation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Request failed during generation"
+ }
+ }
+ }
+ },
+ "429": {
+ "description": "Model is overloaded",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Model is overloaded"
+ }
+ }
+ }
+ },
+ "500": {
+ "description": "Incomplete generation",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/ErrorResponse"
+ },
+ "example": {
+ "error": "Incomplete generation"
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "ErrorResponse": {
+ "type": "object",
+ "required": [
+ "error",
+ "error_type"
+ ],
+ "properties": {
+ "error": {
+ "type": "string"
+ },
+ "error_type": {
+ "type": "string"
+ }
+ }
+ },
+ "GenerateParameters": {
+ "type": "object",
+ "properties": {
+ "height": {
+ "type": "integer",
+ "default": "null",
+ "example": 1,
+ "nullable": true,
+ "minimum": 0.0,
+ "exclusiveMinimum": 0.0
+ },
+ "width": {
+ "type": "integer",
+ "default": "null",
+ "example": 1,
+ "nullable": true,
+ "minimum": 0.0,
+ "exclusiveMinimum": 0.0
+ },
+ "num_inference_steps": {
+ "type": "integer",
+ "default": 50,
+ "example": 1,
+ "nullable": true,
+ "minimum": 0.0,
+ "exclusiveMinimum": 0.0
+ },
+ "guidance_scale": {
+ "type": "number",
+ "format": "float",
+ "default": 7.5
+ },
+ "negative_prompt": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "num_images_per_prompt": {
+ "type": "integer",
+ "default": 1,
+ "example": 1
+ },
+ "clip_skip": {
+ "type": "integer",
+ "default": "null",
+ "example": 1
+ }
+ }
+ },
+ "GenerateRequest": {
+ "type": "object",
+ "required": [
+ "inputs"
+ ],
+ "properties": {
+ "inputs": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "parameters": {
+ "$ref": "#/components/schemas/GenerateParameters"
+ }
+ }
+ }
+ }
+ },
+ "tags": [
+ {
+ "name": "Text to Image Generation Inference",
+ "description": "Hugging Face Text To Image Inference API"
+ }
+ ]
+}
diff --git a/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache b/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache
new file mode 100644
index 00000000000..37b687cf0ff
--- /dev/null
+++ b/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache
@@ -0,0 +1,170 @@
+package {{package}};
+
+import {{invokerPackage}}.ApiClient;
+
+{{#imports}}import {{import}};
+{{/imports}}
+
+{{^fullJavaUtil}}import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;{{/fullJavaUtil}}
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.web.client.RestClientException;
+import org.springframework.web.client.HttpClientErrorException;
+import org.springframework.web.util.UriComponentsBuilder;
+import org.springframework.core.ParameterizedTypeReference;
+import org.springframework.core.io.FileSystemResource;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
+
+{{>generatedAnnotation}}
+@Component("{{package}}.{{classname}}")
+{{#operations}}
+public class {{classname}} {
+ private ApiClient {{localVariablePrefix}}apiClient;
+
+ public {{classname}}() {
+ this(new ApiClient());
+ }
+
+ @Autowired
+ public {{classname}}(ApiClient apiClient) {
+ this.{{localVariablePrefix}}apiClient = apiClient;
+ }
+
+ public ApiClient getApiClient() {
+ return {{localVariablePrefix}}apiClient;
+ }
+
+ public void setApiClient(ApiClient apiClient) {
+ this.{{localVariablePrefix}}apiClient = apiClient;
+ }
+
+ {{#operation}}
+ {{#contents}}
+ /**
+ * {{summary}}
+ * {{notes}}
+ {{#responses}}
+ * {{code}}{{#message}} - {{message}}{{/message}}
+ {{/responses}}
+ {{#parameters}}
+ * @param {{paramName}} {{description}}{{#required}} (required){{/required}}{{^required}} (optional{{#defaultValue}}, default to {{{.}}}{{/defaultValue}}){{/required}}
+ {{/parameters}}
+ {{#returnType}}
+ * @return {{returnType}}
+ {{/returnType}}
+ * @throws RestClientException if an error occurs while attempting to invoke the API
+ {{#externalDocs}}
+ * {{description}}
+ * @see {{summary}} Documentation
+ {{/externalDocs}}
+ */
+ {{#isDeprecated}}
+ @Deprecated
+ {{/isDeprecated}}
+ public {{#returnType}}{{{returnType}}} {{/returnType}}{{^returnType}}void {{/returnType}}{{operationId}}({{#parameters}}{{{dataType}}} {{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}) throws RestClientException {
+ {{#returnType}}
+ return {{operationId}}WithHttpInfo({{#parameters}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}).getBody();
+ {{/returnType}}
+ {{^returnType}}
+ {{operationId}}WithHttpInfo({{#parameters}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}});
+ {{/returnType}}
+ }
+
+ /**
+ * {{summary}}
+ * {{notes}}
+ {{#responses}}
+ *
{{code}}{{#message}} - {{message}}{{/message}}
+ {{/responses}}
+ {{#parameters}}
+ * @param {{paramName}} {{description}}{{#required}} (required){{/required}}{{^required}} (optional{{#defaultValue}}, default to {{{.}}}{{/defaultValue}}){{/required}}
+ {{/parameters}}
+ * @return ResponseEntity<{{#returnType}}{{returnType}}{{/returnType}}{{^returnType}}Void{{/returnType}}>
+ * @throws RestClientException if an error occurs while attempting to invoke the API
+ {{#externalDocs}}
+ * {{description}}
+ * @see {{summary}} Documentation
+ {{/externalDocs}}
+ */
+ {{#isDeprecated}}
+ @Deprecated
+ {{/isDeprecated}}
+ public ResponseEntity<{{#returnType}}{{{returnType}}}{{/returnType}}{{^returnType}}Void{{/returnType}}> {{operationId}}WithHttpInfo({{#parameters}}{{{dataType}}} {{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}) throws RestClientException {
+ Object {{localVariablePrefix}}postBody = {{^isForm}}{{#bodyParam}}{{paramName}}{{/bodyParam}}{{^bodyParam}}null{{/bodyParam}}{{/isForm}}{{#isForm}}null{{/isForm}};
+ {{#parameters}}
+ {{#required}}
+ // verify the required parameter '{{paramName}}' is set
+ if ({{paramName}} == null) {
+ throw new HttpClientErrorException(HttpStatus.BAD_REQUEST, "Missing the required parameter '{{paramName}}' when calling {{operationId}}");
+ }
+ {{/required}}
+ {{/parameters}}
+ {{#hasPathParams}}
+ // create path and map variables
+ final Map uriVariables = new HashMap();
+ {{#pathParams}}
+ uriVariables.put("{{baseName}}", {{{paramName}}});
+ {{/pathParams}}
+ {{/hasPathParams}}
+ String {{localVariablePrefix}}path = UriComponentsBuilder.fromPath("{{{path}}}"){{#hasPathParams}}.buildAndExpand(uriVariables){{/hasPathParams}}{{^hasPathParams}}.build(){{/hasPathParams}}.toUriString();
+
+ final MultiValueMap {{localVariablePrefix}}queryParams = new LinkedMultiValueMap();
+ final HttpHeaders {{localVariablePrefix}}headerParams = new HttpHeaders();
+ final MultiValueMap {{localVariablePrefix}}formParams = new LinkedMultiValueMap();
+ {{#hasQueryParams}}
+ {{#queryParams}}
+ {{localVariablePrefix}}queryParams.putAll({{localVariablePrefix}}apiClient.parameterToMultiValueMap({{#collectionFormat}}ApiClient.CollectionFormat.valueOf("{{{collectionFormat}}}".toUpperCase()){{/collectionFormat}}{{^collectionFormat}}null{{/collectionFormat}}, "{{baseName}}", {{paramName}}));
+ {{/queryParams}}
+ {{/hasQueryParams}}
+
+ {{#hasHeaderParams}}
+ {{#headerParams}}
+ if ({{paramName}} != null)
+ {{localVariablePrefix}}headerParams.add("{{baseName}}", {{localVariablePrefix}}apiClient.parameterToString({{paramName}}));
+ {{/headerParams}}
+ {{/hasHeaderParams}}
+ {{#hasFormParams}}
+ {{#isForm}}
+ {{#formParams}}
+ if ({{paramName}} != null)
+ {{localVariablePrefix}}formParams.add("{{baseName}}", {{#is this 'binary'}}new FileSystemResource({{paramName}}){{/is}}{{#isNot this 'binary'}}{{paramName}}{{/isNot}});
+ {{/formParams}}
+ {{/isForm}}
+ {{/hasFormParams}}
+
+ {{^hasHeaderParams}}
+ final String[] {{localVariablePrefix}}accepts = { {{#hasProduces}}
+ {{#produces}}"{{mediaType}}"{{#hasMore}}, {{/hasMore}}{{/produces}}
+ {{/hasProduces}} };
+ final List {{localVariablePrefix}}accept = {{localVariablePrefix}}apiClient.selectHeaderAccept({{localVariablePrefix}}accepts);
+ {{/hasHeaderParams}}
+ final String[] {{localVariablePrefix}}contentTypes = { {{#hasConsumes}}
+ {{#consumes}}"{{mediaType}}"{{#hasMore}}, {{/hasMore}}{{/consumes}}
+ {{/hasConsumes}} };
+ final MediaType {{localVariablePrefix}}contentType = {{localVariablePrefix}}apiClient.selectHeaderContentType({{localVariablePrefix}}contentTypes);
+
+ String[] {{localVariablePrefix}}authNames = new String[] { {{#authMethods}}"{{name}}"{{#hasMore}}, {{/hasMore}}{{/authMethods}} };
+
+ {{#returnType}}ParameterizedTypeReference<{{{returnType}}}> {{localVariablePrefix}}returnType = new ParameterizedTypeReference<{{{returnType}}}>() {};{{/returnType}}{{^returnType}}ParameterizedTypeReference {{localVariablePrefix}}returnType = new ParameterizedTypeReference() {};{{/returnType}}
+
+ {{^hasHeaderParams}}
+ return {{localVariablePrefix}}apiClient.invokeAPI({{localVariablePrefix}}path, HttpMethod.{{httpMethod}}, {{localVariablePrefix}}queryParams, {{localVariablePrefix}}postBody, {{localVariablePrefix}}headerParams, {{localVariablePrefix}}formParams, {{localVariablePrefix}}accept, {{localVariablePrefix}}contentType, {{localVariablePrefix}}authNames, {{localVariablePrefix}}returnType);
+ {{/hasHeaderParams}}
+ {{#hasHeaderParams}}
+ return {{localVariablePrefix}}apiClient.invokeAPI({{localVariablePrefix}}path, HttpMethod.{{httpMethod}}, {{localVariablePrefix}}queryParams, {{localVariablePrefix}}postBody, {{localVariablePrefix}}headerParams, {{localVariablePrefix}}formParams, null, {{localVariablePrefix}}contentType, {{localVariablePrefix}}authNames, {{localVariablePrefix}}returnType);
+ {{/hasHeaderParams}}
+ }
+ {{/contents}}
+ {{/operation}}
+}
+{{/operations}}
diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java
index 5f933a09c8c..8143f24b6eb 100644
--- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java
+++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java
@@ -36,4 +36,16 @@ public HuggingfaceChatModel huggingfaceChatModel() {
return huggingfaceChatModel;
}
+ @Bean
+ public HuggingfaceImageModel huggingfaceImageModel() {
+ String apiKey = System.getenv("HUGGINGFACE_API_KEY");
+ if (!StringUtils.hasText(apiKey)) {
+ throw new IllegalArgumentException(
+ "You must provide an API key. Put it in an environment variable under the name HUGGINGFACE_API_KEY");
+ }
+ HuggingfaceImageModel huggingfaceImageModel = new HuggingfaceImageModel(apiKey,
+ System.getenv("HUGGINGFACE_CHAT_URL"));
+ return huggingfaceImageModel;
+ }
+
}
diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java
new file mode 100644
index 00000000000..d7b97a0282e
--- /dev/null
+++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java
@@ -0,0 +1,33 @@
+package org.springframework.ai.huggingface.client;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.ai.huggingface.HuggingfaceImageModel;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.test.context.SpringBootTest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@SpringBootTest
+@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+")
+@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_TEXT_TO_IMAGE_URL", matches = ".+")
+public class TextToImageClientIT {
+
+ @Autowired
+ protected HuggingfaceImageModel huggingfaceImageModel;
+
+ @Test
+ void helloWorldCompletion() {
+ String textToImageInstruct = """
+ A cat touching a mirror and seeing its reflection for the first time.
+ """;
+ ImagePrompt prompt = new ImagePrompt(textToImageInstruct);
+ ImageResponse response = huggingfaceImageModel.call(prompt);
+ assertThat(response.getResult().getOutput()).isNotNull();
+ assertThat(response.getResult().getMetadata()).isNotNull();
+ assertThat(response.getResult().getOutput().getB64Json()).isNotEmpty();
+ }
+
+}