From b7cf095902de4561437c77b68bd578db834f1b44 Mon Sep 17 00:00:00 2001 From: Denis Lobo Date: Sun, 4 Aug 2024 17:06:05 +0200 Subject: [PATCH 1/2] add hugging face text to image integration --- models/spring-ai-huggingface/pom.xml | 27 +- .../ai/huggingface/HuggingfaceChatModel.java | 18 +- .../ai/huggingface/HuggingfaceImageModel.java | 96 +++++++ .../text2image/HuggingfaceImageOptions.java | 191 ++++++++++++++ .../text2image/HuggingfaceImagePrompt.java | 31 +++ .../src/main/resources/openapi-imagegen.json | 249 ++++++++++++++++++ .../Java/libraries/resttemplate/api.mustache | 170 ++++++++++++ .../HuggingfaceTestConfiguration.java | 12 + .../client/TextToImageClientIT.java | 33 +++ 9 files changed, 818 insertions(+), 9 deletions(-) create mode 100644 models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java create mode 100644 models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java create mode 100644 models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImagePrompt.java create mode 100644 models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json create mode 100644 models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache create mode 100644 models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java diff --git a/models/spring-ai-huggingface/pom.xml b/models/spring-ai-huggingface/pom.xml index 9a1d487dbb4..72b25db456d 100644 --- a/models/spring-ai-huggingface/pom.xml +++ b/models/spring-ai-huggingface/pom.xml @@ -93,6 +93,7 @@ 3.0.46 + generate-chat-api generate @@ -101,7 +102,7 @@ java resttemplate org.springframework.ai.huggingface.api - org.springframework.ai.huggingface.model + org.springframework.ai.huggingface.model.chat org.springframework.ai.huggingface.invoker false false @@ -113,6 +114,30 @@ + + generate-imagegen-api + + generate + + + ${project.basedir}/src/main/resources/openapi-imagegen.json + java + resttemplate + org.springframework.ai.huggingface.api + org.springframework.ai.huggingface.model.imagegen + org.springframework.ai.huggingface.invoker + false + false + + src/main/resources/swagger-codegen/templates/Java + + src/main/java + java8 + + true + + + diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index aa222f4f320..acb23844c75 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -16,10 +16,6 @@ package org.springframework.ai.huggingface; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -32,10 +28,12 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.huggingface.api.TextGenerationInferenceApi; import org.springframework.ai.huggingface.invoker.ApiClient; -import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails; -import org.springframework.ai.huggingface.model.CompatGenerateRequest; -import org.springframework.ai.huggingface.model.GenerateParameters; -import org.springframework.ai.huggingface.model.GenerateResponse; + +import org.springframework.ai.huggingface.model.chat.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; /** * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference @@ -112,6 +110,10 @@ public ChatResponse call(Prompt prompt) { return new ChatResponse(generations); } + public Info info() { + return this.textGenApi.getModelInfo(); + } + /** * Gets the maximum number of new tokens to be generated. * @return The maximum number of new tokens. diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java new file mode 100644 index 00000000000..2480343ded1 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java @@ -0,0 +1,96 @@ +package org.springframework.ai.huggingface; + +import org.springframework.ai.huggingface.api.ImageGenerationInferenceApi; +import org.springframework.ai.huggingface.text2image.HuggingfaceImageOptions; +import org.springframework.ai.huggingface.invoker.ApiClient; +import org.springframework.ai.huggingface.model.imagegen.GenerateParameters; +import org.springframework.ai.huggingface.model.imagegen.GenerateRequest; +import org.springframework.ai.image.*; + +import java.util.Base64; +import java.util.List; + +/** + * An implementation of {@link ImageModel} that interfaces with HuggingFace Inference + * Endpoints for text-to-image generation. + * + * @author Denis Lobo + */ +public class HuggingfaceImageModel implements ImageModel { + + /** + * Token required for authenticating with the HuggingFace Inference API. + */ + private final String apiToken; + + /** + * Client for making API calls. + */ + private ApiClient apiClient = new ApiClient(); + + private ImageGenerationInferenceApi imageGenApi = new ImageGenerationInferenceApi(); + + /** + * Constructs a new HuggingfaceImageModel with the specified API token and base path. + * @param apiToken The API token for HuggingFace. + * @param basePath The base path for API requests. + */ + public HuggingfaceImageModel(final String apiToken, String basePath) { + this.apiToken = apiToken; + this.apiClient.setBasePath(basePath); + this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken); + this.imageGenApi.setApiClient(this.apiClient); + } + + @Override + public ImageResponse call(ImagePrompt prompt) { + final GenerateParameters generateParameters = createGenerateParameters(prompt.getOptions()); + final GenerateRequest generateRequest = createGenerateRequest(prompt.getInstructions(), generateParameters); + + // huggingface eps with text-to-image models return only single image in default + // mode + final String base64Encoded = generateImage(generateRequest, prompt); + final Image image = new Image(null, base64Encoded); + final ImageGeneration imageGeneration = new ImageGeneration(image); + return new ImageResponse(List.of(imageGeneration), new ImageResponseMetadata()); + } + + private String generateImage(GenerateRequest generateRequest, ImagePrompt prompt) { + final String mimeType = prompt.getOptions().getResponseFormat(); + switch (mimeType) { + case "application/json" -> { + return new String(this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat())); + } + default -> { + byte[] bytes = this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat()); + return Base64.getEncoder().encodeToString(bytes); + } + } + } + + private GenerateRequest createGenerateRequest(List promptInstructs, + GenerateParameters generateParameters) { + final GenerateRequest request = new GenerateRequest(); + final List instructions = promptInstructs.stream().map(ImageMessage::getText).toList(); + + request.setParameters(generateParameters); + request.setInputs(instructions); + return request; + } + + private GenerateParameters createGenerateParameters(ImageOptions options) { + final GenerateParameters params = new GenerateParameters(); + params.setWidth(options.getWidth()); + params.setHeight(options.getHeight()); + params.setNumImagesPerPrompt(options.getN()); + + if (options instanceof HuggingfaceImageOptions hfImageOptions) { + params.setClipSkip(hfImageOptions.getClipSkip()); + params.setGuidanceScale(hfImageOptions.getGuidanceScale()); + params.setNumInferenceSteps(hfImageOptions.getNumInferenceSteps()); + params.setNegativePrompt(List.of(hfImageOptions.getNegativePrompt())); + } + return params; + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java new file mode 100644 index 00000000000..a848a6d6f91 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java @@ -0,0 +1,191 @@ +package org.springframework.ai.huggingface.text2image; + +import org.springframework.ai.image.ImageOptions; + +public class HuggingfaceImageOptions implements ImageOptions { + + private Integer numImagesPerPrompt; + + private String model; + + private Integer width; + + private Integer height; + + private String responseFormat; + + private String negativePrompt; + + private Float sigmaItems; + + private Integer timestepItems; + + private Integer clipSkip; + + private Float guidanceScale; + + private Integer numInferenceSteps; + + @Override + public Integer getN() { + return numImagesPerPrompt; + } + + public void setN(Integer numImagesPerPrompt) { + this.numImagesPerPrompt = numImagesPerPrompt; + } + + @Override + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Integer getWidth() { + return width; + } + + public void setWidth(Integer width) { + this.width = width; + } + + @Override + public Integer getHeight() { + return height; + } + + public void setHeight(Integer height) { + this.height = height; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + public String getNegativePrompt() { + return negativePrompt; + } + + public void setNegativePrompt(String negativePrompt) { + this.negativePrompt = negativePrompt; + } + + public Float getSigmaItems() { + return sigmaItems; + } + + public void setSigmaItems(Float sigmaItems) { + this.sigmaItems = sigmaItems; + } + + public Integer getTimestepItems() { + return timestepItems; + } + + public void setTimestepItems(Integer timestepItems) { + this.timestepItems = timestepItems; + } + + public Integer getClipSkip() { + return clipSkip; + } + + public void setClipSkip(Integer clipSkip) { + this.clipSkip = clipSkip; + } + + public Float getGuidanceScale() { + return guidanceScale; + } + + public void setGuidanceScale(Float guidanceScale) { + this.guidanceScale = guidanceScale; + } + + public Integer getNumInferenceSteps() { + return numInferenceSteps; + } + + public void setNumInferenceSteps(Integer numInferenceSteps) { + this.numInferenceSteps = numInferenceSteps; + } + + public static class Builder { + + private final HuggingfaceImageOptions options = new HuggingfaceImageOptions(); + + public Builder builder() { + return new Builder(); + } + + public Builder withNumImagesPerPrompt(Integer numImagesPerPrompt) { + options.setN(numImagesPerPrompt); + return this; + } + + public Builder withModel(String model) { + options.setModel(model); + return this; + } + + public Builder withResponseFormat(String responseFormat) { + options.setResponseFormat(responseFormat); + return this; + } + + public Builder withWidth(Integer width) { + options.setWidth(width); + return this; + } + + public Builder withHeight(Integer height) { + options.setHeight(height); + return this; + } + + public Builder withNegativePrompt(String negativePrompt) { + options.setNegativePrompt(negativePrompt); + return this; + } + + public Builder withSigmaItems(Float sigmaItems) { + options.setSigmaItems(sigmaItems); + return this; + } + + public Builder withTimestepItems(Integer timestepItems) { + options.setTimestepItems(timestepItems); + return this; + } + + public Builder withClipSkip(Integer clipSkip) { + options.setClipSkip(clipSkip); + return this; + } + + public Builder withGuidanceScale(Float guidanceScale) { + options.setGuidanceScale(guidanceScale); + return this; + } + + public Builder withNumInferenceSteps(Integer numInferenceSteps) { + options.setNumInferenceSteps(numInferenceSteps); + return this; + } + + public HuggingfaceImageOptions build() { + return options; + } + + } + +} diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImagePrompt.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImagePrompt.java new file mode 100644 index 00000000000..7fbb7855097 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImagePrompt.java @@ -0,0 +1,31 @@ +package org.springframework.ai.huggingface.text2image; + +import org.springframework.ai.image.ImageMessage; +import org.springframework.ai.image.ImagePrompt; + +import java.util.Collections; +import java.util.List; + +public class HuggingfaceImagePrompt extends ImagePrompt { + + public HuggingfaceImagePrompt(List messages) { + this(messages, null); + } + + public HuggingfaceImagePrompt(List messages, HuggingfaceImageOptions imageModelOptions) { + super(messages, imageModelOptions); + } + + public HuggingfaceImagePrompt(ImageMessage imageMessage, HuggingfaceImageOptions imageOptions) { + this(Collections.singletonList(imageMessage), imageOptions); + } + + public HuggingfaceImagePrompt(String instructions, HuggingfaceImageOptions imageOptions) { + this(new ImageMessage(instructions), imageOptions); + } + + public HuggingfaceImagePrompt(String instructions) { + this(new ImageMessage(instructions), new HuggingfaceImageOptions()); + } + +} diff --git a/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json b/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json new file mode 100644 index 00000000000..f3df0d9f8d5 --- /dev/null +++ b/models/spring-ai-huggingface/src/main/resources/openapi-imagegen.json @@ -0,0 +1,249 @@ +{ + "openapi": "3.0.3", + "info": { + "title": "Image Generation Inference", + "description": "Image Generation Webserver", + "contact": { + "name": "Denis Lobo" + }, + "license": { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + }, + "version": "1.0.2" + }, + "paths": { + "/": { + "post": { + "tags": [ + "Image Generation Inference" + ], + "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`", + "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`", + "operationId": "generate", + "parameters": [ + { + "name": "Accept", + "in": "header", + "required": false, + "schema": { + "type": "string", + "enum": [ + "application/json", + "image/png", + "image/jpeg", + "image/webp", + "image/gif", + "image/tiff", + "image/bmp" + ] + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Generated image", + "content": { + "application/json": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/png": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/jpeg": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/bmp": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/gif": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/tiff": { + "schema": { + "type": "string", + "format": "byte" + } + }, + "image/webp": { + "schema": { + "type": "string", + "format": "byte" + } + } + } + }, + "422": { + "description": "Input validation error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Input validation error" + } + } + } + }, + "424": { + "description": "Generation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Request failed during generation" + } + } + } + }, + "429": { + "description": "Model is overloaded", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Model is overloaded" + } + } + } + }, + "500": { + "description": "Incomplete generation", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + }, + "example": { + "error": "Incomplete generation" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "ErrorResponse": { + "type": "object", + "required": [ + "error", + "error_type" + ], + "properties": { + "error": { + "type": "string" + }, + "error_type": { + "type": "string" + } + } + }, + "GenerateParameters": { + "type": "object", + "properties": { + "height": { + "type": "integer", + "default": "null", + "example": 1, + "nullable": true, + "minimum": 0.0, + "exclusiveMinimum": 0.0 + }, + "width": { + "type": "integer", + "default": "null", + "example": 1, + "nullable": true, + "minimum": 0.0, + "exclusiveMinimum": 0.0 + }, + "num_inference_steps": { + "type": "integer", + "default": 50, + "example": 1, + "nullable": true, + "minimum": 0.0, + "exclusiveMinimum": 0.0 + }, + "guidance_scale": { + "type": "number", + "format": "float", + "default": 7.5 + }, + "negative_prompt": { + "type": "array", + "items": { + "type": "string" + } + }, + "num_images_per_prompt": { + "type": "integer", + "default": 1, + "example": 1 + }, + "clip_skip": { + "type": "integer", + "default": "null", + "example": 1 + } + } + }, + "GenerateRequest": { + "type": "object", + "required": [ + "inputs" + ], + "properties": { + "inputs": { + "type": "array", + "items": { + "type": "string" + } + }, + "parameters": { + "$ref": "#/components/schemas/GenerateParameters" + } + } + } + } + }, + "tags": [ + { + "name": "Text to Image Generation Inference", + "description": "Hugging Face Text To Image Inference API" + } + ] +} diff --git a/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache b/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache new file mode 100644 index 00000000000..37b687cf0ff --- /dev/null +++ b/models/spring-ai-huggingface/src/main/resources/swagger-codegen/templates/Java/libraries/resttemplate/api.mustache @@ -0,0 +1,170 @@ +package {{package}}; + +import {{invokerPackage}}.ApiClient; + +{{#imports}}import {{import}}; +{{/imports}} + +{{^fullJavaUtil}}import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map;{{/fullJavaUtil}} + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.FileSystemResource; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; + +{{>generatedAnnotation}} +@Component("{{package}}.{{classname}}") +{{#operations}} +public class {{classname}} { + private ApiClient {{localVariablePrefix}}apiClient; + + public {{classname}}() { + this(new ApiClient()); + } + + @Autowired + public {{classname}}(ApiClient apiClient) { + this.{{localVariablePrefix}}apiClient = apiClient; + } + + public ApiClient getApiClient() { + return {{localVariablePrefix}}apiClient; + } + + public void setApiClient(ApiClient apiClient) { + this.{{localVariablePrefix}}apiClient = apiClient; + } + + {{#operation}} + {{#contents}} + /** + * {{summary}} + * {{notes}} + {{#responses}} + *

{{code}}{{#message}} - {{message}}{{/message}} + {{/responses}} + {{#parameters}} + * @param {{paramName}} {{description}}{{#required}} (required){{/required}}{{^required}} (optional{{#defaultValue}}, default to {{{.}}}{{/defaultValue}}){{/required}} + {{/parameters}} + {{#returnType}} + * @return {{returnType}} + {{/returnType}} + * @throws RestClientException if an error occurs while attempting to invoke the API + {{#externalDocs}} + * {{description}} + * @see {{summary}} Documentation + {{/externalDocs}} + */ + {{#isDeprecated}} + @Deprecated + {{/isDeprecated}} + public {{#returnType}}{{{returnType}}} {{/returnType}}{{^returnType}}void {{/returnType}}{{operationId}}({{#parameters}}{{{dataType}}} {{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}) throws RestClientException { + {{#returnType}} + return {{operationId}}WithHttpInfo({{#parameters}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}).getBody(); + {{/returnType}} + {{^returnType}} + {{operationId}}WithHttpInfo({{#parameters}}{{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}); + {{/returnType}} + } + + /** + * {{summary}} + * {{notes}} + {{#responses}} + *

{{code}}{{#message}} - {{message}}{{/message}} + {{/responses}} + {{#parameters}} + * @param {{paramName}} {{description}}{{#required}} (required){{/required}}{{^required}} (optional{{#defaultValue}}, default to {{{.}}}{{/defaultValue}}){{/required}} + {{/parameters}} + * @return ResponseEntity<{{#returnType}}{{returnType}}{{/returnType}}{{^returnType}}Void{{/returnType}}> + * @throws RestClientException if an error occurs while attempting to invoke the API + {{#externalDocs}} + * {{description}} + * @see {{summary}} Documentation + {{/externalDocs}} + */ + {{#isDeprecated}} + @Deprecated + {{/isDeprecated}} + public ResponseEntity<{{#returnType}}{{{returnType}}}{{/returnType}}{{^returnType}}Void{{/returnType}}> {{operationId}}WithHttpInfo({{#parameters}}{{{dataType}}} {{paramName}}{{#hasMore}}, {{/hasMore}}{{/parameters}}) throws RestClientException { + Object {{localVariablePrefix}}postBody = {{^isForm}}{{#bodyParam}}{{paramName}}{{/bodyParam}}{{^bodyParam}}null{{/bodyParam}}{{/isForm}}{{#isForm}}null{{/isForm}}; + {{#parameters}} + {{#required}} + // verify the required parameter '{{paramName}}' is set + if ({{paramName}} == null) { + throw new HttpClientErrorException(HttpStatus.BAD_REQUEST, "Missing the required parameter '{{paramName}}' when calling {{operationId}}"); + } + {{/required}} + {{/parameters}} + {{#hasPathParams}} + // create path and map variables + final Map uriVariables = new HashMap(); + {{#pathParams}} + uriVariables.put("{{baseName}}", {{{paramName}}}); + {{/pathParams}} + {{/hasPathParams}} + String {{localVariablePrefix}}path = UriComponentsBuilder.fromPath("{{{path}}}"){{#hasPathParams}}.buildAndExpand(uriVariables){{/hasPathParams}}{{^hasPathParams}}.build(){{/hasPathParams}}.toUriString(); + + final MultiValueMap {{localVariablePrefix}}queryParams = new LinkedMultiValueMap(); + final HttpHeaders {{localVariablePrefix}}headerParams = new HttpHeaders(); + final MultiValueMap {{localVariablePrefix}}formParams = new LinkedMultiValueMap(); + {{#hasQueryParams}} + {{#queryParams}} + {{localVariablePrefix}}queryParams.putAll({{localVariablePrefix}}apiClient.parameterToMultiValueMap({{#collectionFormat}}ApiClient.CollectionFormat.valueOf("{{{collectionFormat}}}".toUpperCase()){{/collectionFormat}}{{^collectionFormat}}null{{/collectionFormat}}, "{{baseName}}", {{paramName}})); + {{/queryParams}} + {{/hasQueryParams}} + + {{#hasHeaderParams}} + {{#headerParams}} + if ({{paramName}} != null) + {{localVariablePrefix}}headerParams.add("{{baseName}}", {{localVariablePrefix}}apiClient.parameterToString({{paramName}})); + {{/headerParams}} + {{/hasHeaderParams}} + {{#hasFormParams}} + {{#isForm}} + {{#formParams}} + if ({{paramName}} != null) + {{localVariablePrefix}}formParams.add("{{baseName}}", {{#is this 'binary'}}new FileSystemResource({{paramName}}){{/is}}{{#isNot this 'binary'}}{{paramName}}{{/isNot}}); + {{/formParams}} + {{/isForm}} + {{/hasFormParams}} + + {{^hasHeaderParams}} + final String[] {{localVariablePrefix}}accepts = { {{#hasProduces}} + {{#produces}}"{{mediaType}}"{{#hasMore}}, {{/hasMore}}{{/produces}} + {{/hasProduces}} }; + final List {{localVariablePrefix}}accept = {{localVariablePrefix}}apiClient.selectHeaderAccept({{localVariablePrefix}}accepts); + {{/hasHeaderParams}} + final String[] {{localVariablePrefix}}contentTypes = { {{#hasConsumes}} + {{#consumes}}"{{mediaType}}"{{#hasMore}}, {{/hasMore}}{{/consumes}} + {{/hasConsumes}} }; + final MediaType {{localVariablePrefix}}contentType = {{localVariablePrefix}}apiClient.selectHeaderContentType({{localVariablePrefix}}contentTypes); + + String[] {{localVariablePrefix}}authNames = new String[] { {{#authMethods}}"{{name}}"{{#hasMore}}, {{/hasMore}}{{/authMethods}} }; + + {{#returnType}}ParameterizedTypeReference<{{{returnType}}}> {{localVariablePrefix}}returnType = new ParameterizedTypeReference<{{{returnType}}}>() {};{{/returnType}}{{^returnType}}ParameterizedTypeReference {{localVariablePrefix}}returnType = new ParameterizedTypeReference() {};{{/returnType}} + + {{^hasHeaderParams}} + return {{localVariablePrefix}}apiClient.invokeAPI({{localVariablePrefix}}path, HttpMethod.{{httpMethod}}, {{localVariablePrefix}}queryParams, {{localVariablePrefix}}postBody, {{localVariablePrefix}}headerParams, {{localVariablePrefix}}formParams, {{localVariablePrefix}}accept, {{localVariablePrefix}}contentType, {{localVariablePrefix}}authNames, {{localVariablePrefix}}returnType); + {{/hasHeaderParams}} + {{#hasHeaderParams}} + return {{localVariablePrefix}}apiClient.invokeAPI({{localVariablePrefix}}path, HttpMethod.{{httpMethod}}, {{localVariablePrefix}}queryParams, {{localVariablePrefix}}postBody, {{localVariablePrefix}}headerParams, {{localVariablePrefix}}formParams, null, {{localVariablePrefix}}contentType, {{localVariablePrefix}}authNames, {{localVariablePrefix}}returnType); + {{/hasHeaderParams}} + } + {{/contents}} + {{/operation}} +} +{{/operations}} diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java index 5f933a09c8c..8143f24b6eb 100644 --- a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/HuggingfaceTestConfiguration.java @@ -36,4 +36,16 @@ public HuggingfaceChatModel huggingfaceChatModel() { return huggingfaceChatModel; } + @Bean + public HuggingfaceImageModel huggingfaceImageModel() { + String apiKey = System.getenv("HUGGINGFACE_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name HUGGINGFACE_API_KEY"); + } + HuggingfaceImageModel huggingfaceImageModel = new HuggingfaceImageModel(apiKey, + System.getenv("HUGGINGFACE_CHAT_URL")); + return huggingfaceImageModel; + } + } diff --git a/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java new file mode 100644 index 00000000000..d7b97a0282e --- /dev/null +++ b/models/spring-ai-huggingface/src/test/java/org/springframework/ai/huggingface/client/TextToImageClientIT.java @@ -0,0 +1,33 @@ +package org.springframework.ai.huggingface.client; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.huggingface.HuggingfaceImageModel; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "HUGGINGFACE_TEXT_TO_IMAGE_URL", matches = ".+") +public class TextToImageClientIT { + + @Autowired + protected HuggingfaceImageModel huggingfaceImageModel; + + @Test + void helloWorldCompletion() { + String textToImageInstruct = """ + A cat touching a mirror and seeing its reflection for the first time. + """; + ImagePrompt prompt = new ImagePrompt(textToImageInstruct); + ImageResponse response = huggingfaceImageModel.call(prompt); + assertThat(response.getResult().getOutput()).isNotNull(); + assertThat(response.getResult().getMetadata()).isNotNull(); + assertThat(response.getResult().getOutput().getB64Json()).isNotEmpty(); + } + +} From 00387720af08d77dc6ffc7e5f0bed8a1fc96b96d Mon Sep 17 00:00:00 2001 From: lobo Date: Sun, 27 Oct 2024 18:04:15 +0100 Subject: [PATCH 2/2] consolidate implementation with recent changes in upstream --- .../ai/huggingface/HuggingfaceChatModel.java | 4 +++ .../ai/huggingface/HuggingfaceImageModel.java | 22 +++++++++------ .../text2image/HuggingfaceImageOptions.java | 28 +++++++++++++++++++ 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java index acb23844c75..6b02e2fa1bc 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java @@ -34,6 +34,10 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.springframework.ai.huggingface.model.chat.AllOfGenerateResponseDetails; +import org.springframework.ai.huggingface.model.chat.GenerateParameters; +import org.springframework.ai.huggingface.model.chat.GenerateRequest; +import org.springframework.ai.huggingface.model.chat.GenerateResponse; /** * An implementation of {@link ChatModel} that interfaces with HuggingFace Inference diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java index 2480343ded1..21a831bd003 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java @@ -18,6 +18,8 @@ */ public class HuggingfaceImageModel implements ImageModel { + private final String APPLICATION_JSON = "application/json"; + /** * Token required for authenticating with the HuggingFace Inference API. */ @@ -47,8 +49,7 @@ public ImageResponse call(ImagePrompt prompt) { final GenerateParameters generateParameters = createGenerateParameters(prompt.getOptions()); final GenerateRequest generateRequest = createGenerateRequest(prompt.getInstructions(), generateParameters); - // huggingface eps with text-to-image models return only single image in default - // mode + // hf text-to-image endpoints return only a single image in default mode final String base64Encoded = generateImage(generateRequest, prompt); final Image image = new Image(null, base64Encoded); final ImageGeneration imageGeneration = new ImageGeneration(image); @@ -56,15 +57,20 @@ public ImageResponse call(ImagePrompt prompt) { } private String generateImage(GenerateRequest generateRequest, ImagePrompt prompt) { - final String mimeType = prompt.getOptions().getResponseFormat(); - switch (mimeType) { - case "application/json" -> { - return new String(this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat())); + final String responseFormat = prompt.getOptions().getResponseFormat(); + final HuggingfaceImageOptions options = (HuggingfaceImageOptions) prompt.getOptions(); + switch (responseFormat) { + case "base64" -> { + return new String(this.imageGenApi.generate(generateRequest, APPLICATION_JSON)); } - default -> { - byte[] bytes = this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat()); + case "bytes" -> { + byte[] bytes = this.imageGenApi.generate(generateRequest, options.getResponseMimeType()); return Base64.getEncoder().encodeToString(bytes); } + default -> { + throw new UnsupportedOperationException(String + .format("Unsupported response format: %s, should be 'base64' or 'bytes'", responseFormat)); + } } } diff --git a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java index a848a6d6f91..3e807fb2b70 100644 --- a/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java +++ b/models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java @@ -12,8 +12,19 @@ public class HuggingfaceImageOptions implements ImageOptions { private Integer height; + /** + * should be one of 'base64' or 'bytes' + */ private String responseFormat; + private String style; + + /** + * considered only if responseFormat = 'bytes' should be one of 'image/png', + * 'image/jpg', 'image/tiff' etc. + */ + private String responseMimeType; + private String negativePrompt; private Float sigmaItems; @@ -71,6 +82,23 @@ public void setResponseFormat(String responseFormat) { this.responseFormat = responseFormat; } + @Override + public String getStyle() { + return style; + } + + public void setStyle(String style) { + this.style = style; + } + + public String getResponseMimeType() { + return responseMimeType; + } + + public void setResponseMimeType(String responseMimeType) { + this.responseMimeType = responseMimeType; + } + public String getNegativePrompt() { return negativePrompt; }