From 4e4ed6a7a19195bd5d47f2597f0a480c32e541eb Mon Sep 17 00:00:00 2001 From: Ion Linti Date: Fri, 2 May 2025 16:10:34 +0300 Subject: [PATCH 1/2] Fix #9051: add integer dtype check to RandomPosterize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit – Introduce helper _ensure_integer_dtype to validate tensor.dtype via torch.iinfo – Call this helper in RandomPosterize.transform to raise TypeError on non-integer dtypes --- torchvision/transforms/v2/_color.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index bf4ae55d232..198e03398fe 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -3,12 +3,26 @@ from typing import Any, Optional, Union import torch +from torch import Tensor from torchvision import transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform + from ._transform import _RandomApplyTransform from ._utils import query_chw +def _ensure_integer_dtype(tensor: Tensor) -> None: + """ + Checks that the tensor's dtype is integer. + Throws TypeError for float, complex, bool, etc. + """ + try: + torch.iinfo(tensor.dtype) + except (ValueError, TypeError): + raise TypeError( + f"Number of value bits is only defined for integer dtypes, but got {tensor.dtype}" + ) + class Grayscale(Transform): """Convert images or videos to grayscale. @@ -306,9 +320,11 @@ def __init__(self, bits: int, p: float = 0.5) -> None: self.bits = bits def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + # Check that the tensor is integer + if isinstance(inpt, Tensor): + _ensure_integer_dtype(inpt) return self._call_kernel(F.posterize, inpt, bits=self.bits) - class RandomSolarize(_RandomApplyTransform): """Solarize the image or video with a given probability by inverting all pixel values above a threshold. From 9cb6f11d4523f1b2956bb95c592ab331be698a9e Mon Sep 17 00:00:00 2001 From: Ion Linti Date: Fri, 2 May 2025 16:20:11 +0300 Subject: [PATCH 2/2] Add tests for RandomPosterize dtype validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit – Test that RandomPosterize(bits=3, p=1.0) raises TypeError on float32, bool, complex64 with message “Number of value bits is only defined for integer dtypes” – Test that RandomPosterize(bits=4, p=1.0) successfully processes a uint8 tensor --- test/test_transforms_v2.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 94d90b9e2f6..c87feef11c7 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -19,6 +19,7 @@ import torch import torchvision.ops import torchvision.transforms.v2 as transforms +from torchvision.transforms.v2 import RandomPosterize from common_utils import ( assert_equal, @@ -6270,4 +6271,20 @@ def test_different_sizes(self, make_input1, make_input2, query): @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): - query(["blah"]) + query(["blah"] + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bool, torch.complex64]) +def test_random_posterize_dtype_error(dtype): + rp = RandomPosterize(bits=3, p=1.0) + tensor = torch.zeros((1, 3, 5, 5), dtype=dtype) + with pytest.raises(TypeError) as excinfo: + rp(tensor) + assert "Number of value bits is only defined for integer dtypes" in str(excinfo.value) + + +def test_random_posterize_uint8_pass(): + rp = RandomPosterize(bits=4, p=1.0) + tensor = torch.randint(0, 255, (1, 3, 5, 5), dtype=torch.uint8) + out = rp(tensor) + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.uint8