Skip to content

Commit 6c8dad1

Browse files
fix: preserve quality and optimize transfer of prompt images (#570)
* fix: preserve quality and optimize transfer of prompt images
* Move numpy-images to their own test case. Change-Id: Ie6b02c7647487c1df9d4e70e9b8eed70dc8b8fe3
* Format with black. Change-Id: I04550a89eed9bb21c0a8f6f9b6ab76b8b0f41270
---------
Co-authored-by: Mark Daoust <[email protected]>
1 parent 8f7f5cb commit 6c8dad1

File tree

2 files changed

+43
-40
lines changed

2 files changed

+43
-40
lines changed

google/generativeai/types/content_types.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io
2020
import inspect
2121
import mimetypes
22+
import pathlib
2223
import typing
2324
from typing import Any, Callable, Union
2425
from typing_extensions import TypedDict
@@ -30,15 +31,15 @@
3031

3132
if typing.TYPE_CHECKING:
3233
import PIL.Image
33-
import PIL.PngImagePlugin
34+
import PIL.ImageFile
3435
import IPython.display
3536

3637
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
3738
else:
3839
IMAGE_TYPES = ()
3940
try:
4041
import PIL.Image
41-
import PIL.PngImagePlugin
42+
import PIL.ImageFile
4243

4344
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
4445
except ImportError:
@@ -72,46 +73,39 @@
7273
]
7374

7475

75-
def pil_to_blob(img):
76-
# When you load an image with PIL you get a subclass of PIL.Image
77-
# The subclass knows what file type it was loaded from: it has a `.format` class attribute
78-
# and the `get_format_mimetype` method. Convert these back to the same file type.
79-
#
80-
# The base image class doesn't know its file type, it just knows its mode.
81-
# RGBA converts to PNG easily, P[alette] converts to GIF, RGB to JPEG.
82-
# For anything else, I'm not going to bother mapping it out (for now); just convert to RGB and send it as JPEG.
83-
#
84-
# References:
85-
# - file formats: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html
86-
# - image modes: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
87-
88-
bytesio = io.BytesIO()
89-
90-
get_mime = getattr(img, "get_format_mimetype", None)
91-
if get_mime is not None:
92-
# If the image is created from a file, convert back to the same file type.
93-
img.save(bytesio, format=img.format)
94-
mime_type = img.get_format_mimetype()
95-
elif img.mode == "RGBA":
96-
img.save(bytesio, format="PNG")
97-
mime_type = "image/png"
98-
elif img.mode == "P":
99-
img.save(bytesio, format="GIF")
100-
mime_type = "image/gif"
101-
else:
102-
if img.mode != "RGB":
103-
img = img.convert("RGB")
104-
img.save(bytesio, format="JPEG")
105-
mime_type = "image/jpeg"
106-
bytesio.seek(0)
107-
data = bytesio.read()
108-
return protos.Blob(mime_type=mime_type, data=data)
76+
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
77+
# If the image is a local file, return a file-based blob without any modification.
78+
# Otherwise, return a lossless WebP blob (same quality with optimized size).
79+
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
80+
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
81+
return None
82+
filename = str(image.filename)
83+
if not pathlib.Path(filename).is_file():
84+
return None
85+
86+
mime_type = image.get_format_mimetype()
87+
image_bytes = pathlib.Path(filename).read_bytes()
88+
89+
return protos.Blob(mime_type=mime_type, data=image_bytes)
90+
91+
def webp_blob(image: PIL.Image.Image) -> protos.Blob:
92+
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
93+
image_io = io.BytesIO()
94+
image.save(image_io, format="webp", lossless=True)
95+
image_io.seek(0)
96+
97+
mime_type = "image/webp"
98+
image_bytes = image_io.read()
99+
100+
return protos.Blob(mime_type=mime_type, data=image_bytes)
101+
102+
return file_blob(image) or webp_blob(image)
109103

110104

111105
def image_to_blob(image) -> protos.Blob:
112106
if PIL is not None:
113107
if isinstance(image, PIL.Image.Image):
114-
return pil_to_blob(image)
108+
return _pil_to_blob(image)
115109

116110
if IPython is not None:
117111
if isinstance(image, IPython.display.Image):

tests/test_content.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,20 @@ class HasEnum:
8383

8484

8585
class UnitTests(parameterized.TestCase):
86+
8687
@parameterized.named_parameters(
87-
["PIL", PIL.Image.open(TEST_PNG_PATH)],
8888
["RGBA", PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))],
89+
["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))],
90+
["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")],
91+
)
92+
def test_numpy_to_blob(self, image):
93+
blob = content_types.image_to_blob(image)
94+
self.assertIsInstance(blob, protos.Blob)
95+
self.assertEqual(blob.mime_type, "image/webp")
96+
self.assertStartsWith(blob.data, b"RIFF \x00\x00\x00WEBPVP8L")
97+
98+
@parameterized.named_parameters(
99+
["PIL", PIL.Image.open(TEST_PNG_PATH)],
89100
["IPython", IPython.display.Image(filename=TEST_PNG_PATH)],
90101
)
91102
def test_png_to_blob(self, image):
@@ -96,7 +107,6 @@ def test_png_to_blob(self, image):
96107

97108
@parameterized.named_parameters(
98109
["PIL", PIL.Image.open(TEST_JPG_PATH)],
99-
["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))],
100110
["IPython", IPython.display.Image(filename=TEST_JPG_PATH)],
101111
)
102112
def test_jpg_to_blob(self, image):
@@ -107,7 +117,6 @@ def test_jpg_to_blob(self, image):
107117

108118
@parameterized.named_parameters(
109119
["PIL", PIL.Image.open(TEST_GIF_PATH)],
110-
["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")],
111120
["IPython", IPython.display.Image(filename=TEST_GIF_PATH)],
112121
)
113122
def test_gif_to_blob(self, image):

0 commit comments

Comments
 (0)