
Commit 9d2cee8

CLIPFeatureExtractor should resize images with kept aspect ratio (#11994)
* Resize with kept aspect ratio
* Fixed failed test
* Overload center_crop and resize methods instead
* resize should handle non-PIL images
* update slow test
* Tensor => tensor

Co-authored-by: patil-suraj <surajp815@gmail.com>
1 parent 472a867 commit 9d2cee8

2 files changed: +56 −2 lines changed


src/transformers/models/clip/feature_extraction_clip.py (+53)

@@ -154,3 +154,56 @@ def __call__(
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

         return encoded_inputs
+
+    def center_crop(self, image, size):
+        """
+        Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped
+        to the size given, it will be padded (so the returned result has the size asked).
+
+        Args:
+            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
+                The image to crop.
+            size (:obj:`int` or :obj:`Tuple[int, int]`):
+                The size to which the image is cropped.
+        """
+        self._ensure_format_supported(image)
+        if not isinstance(size, tuple):
+            size = (size, size)
+
+        if not isinstance(image, Image.Image):
+            image = self.to_pil_image(image)
+
+        image_width, image_height = image.size
+        crop_height, crop_width = size
+
+        crop_top = int((image_height - crop_height + 1) * 0.5)
+        crop_left = int((image_width - crop_width + 1) * 0.5)
+
+        return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
+
+    def resize(self, image, size, resample=Image.BICUBIC):
+        """
+        Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.
+
+        Args:
+            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
+                The image to resize.
+            size (:obj:`int` or :obj:`Tuple[int, int]`):
+                The size to use for resizing the image. If an :obj:`int` is given, the shorter side is matched to it.
+            resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`):
+                The filter to use for resampling.
+        """
+        self._ensure_format_supported(image)
+
+        if not isinstance(image, Image.Image):
+            image = self.to_pil_image(image)
+        if isinstance(size, tuple):
+            new_w, new_h = size
+        else:
+            width, height = image.size
+            short, long = (width, height) if width <= height else (height, width)
+            if short == size:
+                return image
+            new_short, new_long = size, int(size * long / short)
+            new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
+        return image.resize((new_w, new_h), resample)
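
To make the padding behavior of the new center_crop concrete, here is a minimal standalone sketch of the same arithmetic, assuming only Pillow is installed; the center_crop function below is a local re-implementation for illustration, not an import from transformers.

from PIL import Image

def center_crop(image, size):
    # Mirrors the arithmetic of the new CLIPFeatureExtractor.center_crop.
    if not isinstance(size, tuple):
        size = (size, size)
    image_width, image_height = image.size
    crop_height, crop_width = size
    crop_top = int((image_height - crop_height + 1) * 0.5)
    crop_left = int((image_width - crop_width + 1) * 0.5)
    # PIL fills with black wherever the crop box extends past the image
    # borders, which is how a too-small image comes back padded to the
    # requested size.
    return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))

img = Image.new("RGB", (256, 192))
assert center_crop(img, 128).size == (128, 128)  # ordinary center crop
assert center_crop(img, 224).size == (224, 224)  # 192 < 224, so the result is padded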
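
The headline change is that resize with an integer size now scales the shorter side to that value and the longer side proportionally, instead of squashing the image to a square. A minimal sketch of that logic, again assuming only Pillow and re-implementing the method locally:

from PIL import Image

def resize_shorter_side(image, size, resample=Image.BICUBIC):
    # Mirrors the aspect-ratio-preserving branch of the new resize method.
    width, height = image.size
    short, long = (width, height) if width <= height else (height, width)
    if short == size:
        return image  # shorter side already matches
    new_short, new_long = size, int(size * long / short)
    new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
    return image.resize((new_w, new_h), resample)

# A 640x480 image resized with size=224: the shorter side (480) becomes 224
# and the longer side scales to int(640 * 224 / 480) = 298.
img = Image.new("RGB", (640, 480))
assert resize_shorter_side(img, 224).size == (298, 224)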

tests/test_modeling_clip.py (+3 −2)

@@ -544,7 +544,8 @@ def test_inference(self):
         ).to(torch_device)

         # forward pass
-        outputs = model(**inputs)
+        with torch.no_grad():
+            outputs = model(**inputs)

         # verify the logits
         self.assertEqual(
@@ -556,6 +557,6 @@ def test_inference(self):
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )

-        expected_logits = torch.tensor([[24.5056, 18.8076]], device=torch_device)
+        expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)

         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
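
For context on the test change: wrapping the forward pass in torch.no_grad() skips building the autograd graph during pure inference, and the expected logits are updated because the aspect-ratio-preserving resize feeds the model slightly different pixels. A self-contained sketch of the same pattern, assuming only torch; DummyModel is a hypothetical stand-in for the real CLIP model:

import torch

class DummyModel(torch.nn.Module):
    def forward(self, pixel_values):
        # Fixed logits so the tolerance check below is deterministic.
        return torch.tensor([[24.5701, 19.3049]])

model = DummyModel().eval()
inputs = {"pixel_values": torch.zeros(1, 3, 224, 224)}

with torch.no_grad():  # no gradient tracking during inference
    logits_per_image = model(**inputs)

expected_logits = torch.tensor([[24.5701, 19.3049]])
assert torch.allclose(logits_per_image, expected_logits, atol=1e-3)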
