
Commit 7929072

Authored by NielsRogge, Wauplin, github-actions[bot], and hanouticelina
Add PIL Image support to InferenceClient (#3199)
* Add PIL Image support to InferenceClient

  - Add PIL Image to ContentT type hints for type checking
  - Update _open_as_binary to handle PIL Images by converting to bytes
  - Update _as_url to detect MIME type from PIL Image format
  - Update docstrings to indicate PIL Image support for all image methods
  - Fixes #3191: Make InferenceClient accept Pillow images

  This enables iterative image editing workflows where PIL Images returned by image_to_image can be directly passed back to image_to_image methods without the 'Unsupported input type for image' error.

* Make style
* Address comments
* Apply suggestion from @Wauplin
* Apply suggestion from @Wauplin
* Apply style fixes
* Apply suggestions from code review

  Co-authored-by: célina <hanouticelina@gmail.com>

* add tests
* style

---------

Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: célina <hanouticelina@gmail.com>
1 parent 7660014 commit 7929072
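
For context, a minimal usage sketch of the workflow this commit enables: a PIL Image returned by `image_to_image` can be passed straight back into the next call. The file paths, prompts, and provider setup below are illustrative assumptions, not part of the commit.

from huggingface_hub import InferenceClient
from PIL import Image

client = InferenceClient()  # assumes a token/inference provider is already configured

image = Image.open("cat.png")  # any local image; the path is a placeholder
for prompt in ["turn it into a watercolor painting", "add a red scarf"]:
    # Previously, passing the returned PIL Image here raised "Unsupported input type for image";
    # it is now serialized to bytes internally before being sent.
    image = client.image_to_image(image, prompt=prompt)

image.save("cat_edited.png")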

4 files changed: 86 additions, 38 deletions

src/huggingface_hub/inference/_client.py

Lines changed: 14 additions & 14 deletions
@@ -1154,8 +1154,8 @@ def image_classification(
         Perform image classification on the given image using the specified model.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to classify. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
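
Every docstring hunk in this file (and in the async client below) makes the same change to the accepted input types. As a quick hedged sketch of the documented behavior (the model ID is an illustrative choice, not something this commit prescribes), classification now accepts a PIL Image directly:

from huggingface_hub import InferenceClient
from PIL import Image

client = InferenceClient(model="google/vit-base-patch16-224")  # illustrative model ID
predictions = client.image_classification(Image.open("cat.png"))  # PIL Image input, path is a placeholder
for prediction in predictions:
    print(prediction.label, prediction.score)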
@@ -1212,8 +1212,8 @@ def image_segmentation(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to segment. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
@@ -1284,8 +1284,8 @@ def image_to_image(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             prompt (`str`, *optional*):
                 The text prompt to guide the image generation.
             negative_prompt (`str`, *optional*):
@@ -1347,8 +1347,8 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> Imag
         (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -1398,8 +1398,8 @@ def object_detection(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
@@ -2973,8 +2973,8 @@ def visual_question_answering(
         Answering open-ended questions based on an image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             question (`str`):
                 Question to be answered.
             model (`str`, *optional*):
@@ -3140,8 +3140,8 @@ def zero_shot_image_classification(
         Provide input image and text labels to predict text labels for the image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             candidate_labels (`List[str]`):
                 The candidate labels for this image
             labels (`List[str]`, *optional*):

src/huggingface_hub/inference/_common.py

Lines changed: 30 additions & 10 deletions
@@ -62,7 +62,7 @@
 UrlT = str
 PathT = Union[str, Path]
 BinaryT = Union[bytes, BinaryIO]
-ContentT = Union[BinaryT, PathT, UrlT]
+ContentT = Union[BinaryT, PathT, UrlT, "Image"]

 # Use to set a Accept: image/png header
 TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
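
The `"Image"` member is added as a string so that Pillow remains an optional dependency. A sketch of the usual pattern behind such a forward reference; the exact import guard in `_common.py` is not shown in this diff, so treat it as an assumption:

from pathlib import Path
from typing import TYPE_CHECKING, BinaryIO, Union

if TYPE_CHECKING:
    from PIL.Image import Image  # assumed guard: only evaluated by static type checkers

UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
# The quoted "Image" stays an unevaluated forward reference at runtime,
# so importing the module never requires PIL to be installed.
ContentT = Union[BinaryT, PathT, UrlT, "Image"]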
@@ -161,11 +161,10 @@ def _open_as_binary(

 @contextmanager  # type: ignore
 def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
-    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.
+    """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image.

-    Do nothing if `content` is None,
+    Do nothing if `content` is None.

-    TODO: handle a PIL.Image as input
     TODO: handle base64 as input
     """
     # If content is a string => must be either a URL or a path
@@ -186,9 +185,21 @@ def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT],
         logger.debug(f"Opening content from {content}")
         with content.open("rb") as f:
             yield f
-    else:
-        # Otherwise: already a file-like object or None
-        yield content
+        return
+
+    # If content is a PIL Image => convert to bytes
+    if is_pillow_available():
+        from PIL import Image
+
+        if isinstance(content, Image.Image):
+            logger.debug("Converting PIL Image to bytes")
+            buffer = io.BytesIO()
+            content.save(buffer, format=content.format or "PNG")
+            yield buffer.getvalue()
+            return
+
+    # Otherwise: already a file-like object or None
+    yield content  # type: ignore


 def _b64_encode(content: ContentT) -> str:
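
In isolation, the conversion added above amounts to serializing the image into an in-memory buffer and falling back to PNG when the image carries no format (typical for images created in memory rather than opened from disk). A standalone sketch with a hypothetical helper name:

import io

from PIL import Image

def pil_image_to_bytes(image: Image.Image) -> bytes:  # hypothetical helper, not part of huggingface_hub
    buffer = io.BytesIO()
    image.save(buffer, format=image.format or "PNG")  # same fallback as _open_as_binary
    return buffer.getvalue()

generated = Image.new("RGB", (64, 64), color="red")  # in-memory image, .format is None
payload = pil_image_to_bytes(generated)
print(payload[:8])  # b'\x89PNG\r\n\x1a\n', the PNG magic bytes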
@@ -202,9 +213,18 @@ def _as_url(content: ContentT, default_mime_type: str) -> str:
     if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")):
         return content

-    mime_type = (
-        mimetypes.guess_type(content, strict=False)[0] if isinstance(content, (str, Path)) else None
-    ) or default_mime_type
+    # Handle MIME type detection for different content types
+    mime_type = None
+    if isinstance(content, (str, Path)):
+        mime_type = mimetypes.guess_type(content, strict=False)[0]
+    elif is_pillow_available():
+        from PIL import Image
+
+        if isinstance(content, Image.Image):
+            # Determine MIME type from PIL Image format, in sync with `_open_as_binary`
+            mime_type = f"image/{(content.format or 'PNG').lower()}"
+
+    mime_type = mime_type or default_mime_type
     encoded_data = _b64_encode(content)
     return f"data:{mime_type};base64,{encoded_data}"
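
The MIME detection pairs with the byte conversion above: a PIL Image advertises its format ("PNG", "JPEG", ...), which maps onto an image/<format> MIME type, and formatless images fall back to PNG so the data URL matches the bytes actually produced. A standalone sketch with a hypothetical helper name:

import base64
import io

from PIL import Image

def pil_image_to_data_url(image: Image.Image) -> str:  # hypothetical helper, not part of huggingface_hub
    fmt = image.format or "PNG"  # same fallback as the byte conversion
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/{fmt.lower()};base64,{encoded}"

print(pil_image_to_data_url(Image.new("RGB", (8, 8)))[:30])  # data:image/png;base64,...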

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 14 additions & 14 deletions
@@ -1197,8 +1197,8 @@ async def image_classification(
         Perform image classification on the given image using the specified model.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to classify. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used.
@@ -1256,8 +1256,8 @@ async def image_segmentation(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to segment. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used.
@@ -1329,8 +1329,8 @@ async def image_to_image(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for translation. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             prompt (`str`, *optional*):
                 The text prompt to guide the image generation.
             negative_prompt (`str`, *optional*):
@@ -1393,8 +1393,8 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -
         (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image..
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                 Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
@@ -1445,8 +1445,8 @@ async def object_detection(
         </Tip>

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             model (`str`, *optional*):
                 The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a
                 deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used.
@@ -3033,8 +3033,8 @@ async def visual_question_answering(
         Answering open-ended questions based on an image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             question (`str`):
                 Question to be answered.
             model (`str`, *optional*):
@@ -3203,8 +3203,8 @@ async def zero_shot_image_classification(
         Provide input image and text labels to predict text labels for the image.

         Args:
-            image (`Union[str, Path, bytes, BinaryIO]`):
-                The input image to caption. It can be raw bytes, an image file, or a URL to an online image.
+            image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`):
+                The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image.
             candidate_labels (`List[str]`):
                 The candidate labels for this image
             labels (`List[str]`, *optional*):

tests/test_inference_client.py

Lines changed: 28 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import io
 import json
 import os
@@ -803,6 +804,15 @@ def test_open_as_binary_from_bytes(self) -> None:
         with _open_as_binary(content_bytes) as content:
             assert content == content_bytes

+    def test_open_as_binary_from_pil_image(self) -> None:
+        pil_image = Image.open(self.image_file)
+        with _open_as_binary(pil_image) as content:
+            assert isinstance(content, bytes)
+
+            buffer = io.BytesIO()
+            pil_image.save(buffer, format=pil_image.format or "PNG")
+            assert content == buffer.getvalue()
+

 class TestHeadersAndCookies(TestBase):
     def test_headers_and_cookies(self) -> None:
@@ -1213,3 +1223,21 @@ def test_as_url(content_input, default_mime_type, expected, is_exact_match, tmp_
         assert result == expected
     else:
         assert result.startswith(expected)
+
+
+def test_as_url_with_pil_image(image_file: str):
+    """Test `_as_url` helper with a PIL Image."""
+    pil_image = Image.open(image_file)
+
+    pil_image.format = "PNG"
+    png_url = _as_url(pil_image, default_mime_type="image/jpeg")
+    assert png_url.startswith("data:image/png;base64,")
+
+    pil_image.format = None
+    png_url = _as_url(pil_image, default_mime_type="image/jpeg")
+    assert png_url.startswith("data:image/png;base64,")
+
+    buffer = io.BytesIO()
+    pil_image.save(buffer, format="PNG")
+    b64_encoded = base64.b64encode(buffer.getvalue()).decode()
+    assert png_url == f"data:image/png;base64,{b64_encoded}"
