Skip to content

Refactor core image #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import java.util.List;

import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageOptions;
Expand Down Expand Up @@ -109,9 +109,8 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
}

List<ImageGeneration> imageGenerationList = imageApiResponse.data().stream().map(entry -> {
return new ImageGeneration(new Image(entry.url(), entry.b64Json()),
new OpenAiImageGenerationMetadata(entry.revisedPrompt()));
}).toList();
return new ImageGeneration(entry.getImage(), new OpenAiImageGenerationMetadata(entry.revisedPrompt()));
}).collect(Collectors.toList());

ImageResponseMetadata openAiImageResponseMetadata = OpenAiImageResponseMetadata.from(imageApiResponse);
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.image.Image;
import org.springframework.ai.openai.image.OpenAiBase64Image;
import org.springframework.ai.openai.image.OpenAiUrlImage;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -129,6 +132,12 @@ public record Data(
@JsonProperty("url") String url,
@JsonProperty("b64_json") String b64Json,
@JsonProperty("revised_prompt") String revisedPrompt) {
// TODO : Develop Image Factory
/**
 * Wraps this entry's payload in the matching {@link Image} implementation.
 * <p>
 * The {@code url} field is checked first, so it wins when both fields are populated.
 * </p>
 * @return an {@link OpenAiUrlImage} when {@code url} is non-null, otherwise an
 * {@link OpenAiBase64Image} built from {@code b64Json}
 * @throws IllegalArgumentException if both {@code url} and {@code b64Json} are
 * {@code null}
 */
public Image getImage() {
	if (url == null && b64Json == null) {
		throw new IllegalArgumentException("Entry must have either url or b64Json");
	}
	if (url != null) {
		return new OpenAiUrlImage(url);
	}
	return new OpenAiBase64Image(b64Json);
}
}
// @formatter:on

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.image;

import org.springframework.ai.image.AbstractImage;

/**
 * {@link AbstractImage} specialization whose payload is a {@code b64_json} string, i.e.
 * image content carried inline as Base64 text rather than referenced by URL.
 * <p>
 * Every instance is tagged with {@link OpenAiImageType#BASE64}.
 * </p>
 *
 * @author youngmon
 * @version 0.8.1
 */
public class OpenAiBase64Image extends AbstractImage<String> {

	/**
	 * Creates an image backed by the given {@code b64_json} payload.
	 * <p>
	 * The argument is forwarded as-is to the {@link AbstractImage} constructor together
	 * with {@link OpenAiImageType#BASE64}; this constructor performs no validation of
	 * its own.
	 * </p>
	 * @param b64Json the Base64 text that carries the image data
	 */
	public OpenAiBase64Image(final String b64Json) {
		super(b64Json, OpenAiImageType.BASE64);
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.image;

import java.util.Arrays;
import org.springframework.ai.image.ImageType;

/**
 * Enumerates the image payload kinds used with the OpenAI image generation API. Each
 * constant carries the string value that identifies it: {@code "url"} for
 * {@link #URL} and {@code "b64_json"} for {@link #BASE64}.
 * <p>
 * Use {@link #getValue()} to obtain the string for a constant and
 * {@link #fromValue(String)} to resolve a string back to its constant.
 * </p>
 *
 * @implNote This enum implements the {@link ImageType} interface, so it can be used
 * wherever a generic image type is expected.
 * @author youngmon
 * @version 0.8.1
 */
public enum OpenAiImageType implements ImageType<OpenAiImageType, String> {

	URL("url"), BASE64("b64_json");

	/** String value that identifies this constant. */
	private final String value;

	OpenAiImageType(final String value) {
		this.value = value;
	}

	/**
	 * Returns the string value associated with this constant.
	 * @return {@code "url"} for {@link #URL}, {@code "b64_json"} for {@link #BASE64}
	 */
	@Override
	public String getValue() {
		return this.value;
	}

	/**
	 * Resolves a string value to the constant that carries it. The comparison is exact
	 * (case-sensitive); a {@code null} argument matches no constant.
	 * @param value the string value to look up, e.g. {@code "url"} or
	 * {@code "b64_json"}
	 * @return the constant whose value equals {@code value}
	 * @throws IllegalArgumentException if no constant carries {@code value}; the
	 * message includes the rejected value to ease diagnosis
	 */
	public static OpenAiImageType fromValue(final String value) {
		return Arrays.stream(values())
			.filter(candidate -> candidate.value.equals(value))
			.findAny()
			// Report the offending input; a bare "Invalid Value" gives callers nothing to act on.
			.orElseThrow(() -> new IllegalArgumentException("Invalid Value: " + value));
	}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.image;

import org.springframework.ai.image.AbstractImage;

/**
 * {@link AbstractImage} specialization whose payload is a URL string pointing at the
 * image, so the image is referenced rather than carried inline.
 * <p>
 * Every instance is tagged with {@link OpenAiImageType#URL}.
 * </p>
 *
 * @author youngmon
 * @version 0.8.1
 */
public class OpenAiUrlImage extends AbstractImage<String> {

	/**
	 * Creates an image that refers to the given URL.
	 * <p>
	 * The argument is forwarded as-is to the {@link AbstractImage} constructor together
	 * with {@link OpenAiImageType#URL}; this constructor does not parse or validate the
	 * URL itself.
	 * </p>
	 * @param url the location of the image, as a string
	 */
	public OpenAiUrlImage(final String url) {
		super(url, OpenAiImageType.URL);
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ public void openAiImageTransientError() {
var result = imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678");
assertThat(result.getResult().getOutput().getData()).isEqualTo("url678");
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ void imageAsUrlTest() {

var generation = imageResponse.getResult();
Image image = generation.getOutput();
assertThat(image.getUrl()).isNotEmpty();
// System.out.println(image.getUrl());
assertThat(image.getB64Json()).isNull();
assertThat(image.getType()).isEqualTo(OpenAiImageType.URL);
assertThat(image.getData()).isNotNull();

var imageGenerationMetadata = generation.getMetadata();
Assertions.assertThat(imageGenerationMetadata).isInstanceOf(OpenAiImageGenerationMetadata.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.image;

import static org.mockito.Mockito.when;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.image.Image;
import org.springframework.ai.openai.api.OpenAiImageApi.*;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Verifies that the OpenAI {@link Image} variants report the expected type, type value
 * and data when built from a mocked request/response pair.
 */
@ExtendWith(MockitoExtension.class)
public class OpenAiImageCreationTests {

	@Mock
	OpenAiImageRequest req;

	@Mock
	OpenAiImageResponse res;

	/** A {@code "url"} response format must yield an image of type {@code URL}. */
	@Test
	public void urlCreationTest() {
		String expectedData = "test.url";
		when(req.responseFormat()).thenReturn("url");
		when(res.toString()).thenReturn(expectedData);

		Image created = tmpImageFactory(req, res);

		assertThat(created).isNotNull();
		assertThat(created.getType()).isEqualTo(OpenAiImageType.URL);
		assertThat(created.getType().getValue()).isEqualTo("url");
		assertThat(created.getData()).isEqualTo(expectedData);
	}

	/** A {@code "b64_json"} response format must yield an image of type {@code BASE64}. */
	@Test
	public void b64CreationTest() {
		String expectedData = "test.b64";
		when(req.responseFormat()).thenReturn("b64_json");
		when(res.toString()).thenReturn(expectedData);

		Image created = tmpImageFactory(req, res);

		assertThat(created).isNotNull();
		assertThat(created.getType()).isEqualTo(OpenAiImageType.BASE64);
		assertThat(created.getType().getValue()).isEqualTo("b64_json");
		assertThat(created.getData()).isEqualTo(expectedData);
	}

	/**
	 * Temporary stand-in for an image factory: picks the image class from the request's
	 * response format and uses the response's string form as the image data.
	 */
	private Image tmpImageFactory(OpenAiImageRequest req, OpenAiImageResponse res) {
		// Resolve the type first so an unknown format fails before the response is touched.
		OpenAiImageType requestedType = OpenAiImageType.fromValue(req.responseFormat());
		if (requestedType == OpenAiImageType.BASE64) {
			return new OpenAiBase64Image(res.toString());
		}
		return new OpenAiUrlImage(res.toString());
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageOptions;
Expand Down Expand Up @@ -115,7 +114,7 @@ private static StabilityAiApi.GenerateImageRequest getGenerateImageRequest(Image

private ImageResponse convertResponse(StabilityAiApi.GenerateImageResponse generateImageResponse) {
List<ImageGeneration> imageGenerationList = generateImageResponse.artifacts().stream().map(entry -> {
return new ImageGeneration(new Image(null, entry.base64()),
return new ImageGeneration(entry.getImage(),
new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed()));
}).toList();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.image.Image;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.stabilityai.image.StabilityAiBase64Image;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -190,6 +192,12 @@ public record GenerateImageResponse(@JsonProperty("result") String result,
@JsonProperty("artifacts") List<Artifacts> artifacts) {
public record Artifacts(@JsonProperty("seed") long seed, @JsonProperty("base64") String base64,
@JsonProperty("finishReason") String finishReason) {
// TODO : Develop Image Factory
/**
 * Wraps this artifact's Base64 payload in a {@link StabilityAiBase64Image}.
 * @return a new {@link StabilityAiBase64Image} built from {@code base64}
 * @throws IllegalArgumentException if {@code base64} is {@code null}
 */
public Image getImage() {
	if (base64 == null) {
		throw new IllegalArgumentException("Entry must have base64");
	}
	return new StabilityAiBase64Image(base64);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.stabilityai.image;

import org.springframework.ai.image.AbstractImage;

/**
 * {@link AbstractImage} specialization for the Stability AI module whose payload is a
 * Base64 string, i.e. image content carried as text.
 * <p>
 * Every instance is tagged with {@link StabilityAiImageType#BASE64}.
 * </p>
 *
 * @author youngmon
 * @version 0.8.1
 */
public class StabilityAiBase64Image extends AbstractImage<String> {

	/**
	 * Creates an image backed by the given Base64 payload.
	 * <p>
	 * The argument is forwarded as-is to the {@link AbstractImage} constructor together
	 * with {@link StabilityAiImageType#BASE64}; this constructor performs no validation
	 * of its own.
	 * </p>
	 * @param data the Base64 text that represents the image
	 */
	public StabilityAiBase64Image(final String data) {
		super(data, StabilityAiImageType.BASE64);
	}

}
Loading