
Torchvision transforms do not scale but the docs say that they do #8984

Closed
Stochastic13 opened this issue Mar 18, 2025 · 5 comments

🐛 Describe the bug

The docs for the torchvision AlexNet weights mention that the transforms rescale the values to 0...1 before applying the mean and std normalization. However, this is not the case. Looking at the source code, in torchvision/transforms/_presets.py, the ImageClassification class describes itself as doing this rescaling, but I do not see the corresponding code. Presumably the AlexNet transform inherits this directly. Apologies if I am missing something here!

Here is a reproducible example:

from torchvision.models import AlexNet_Weights
from torchvision import transforms
import torch
import numpy as np
p1 = AlexNet_Weights.IMAGENET1K_V1.transforms()
p2 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
np.random.seed(42)
n = np.random.randint(0, 255, (1, 3, 224, 224))

for label, p in zip(['AlexNet_Weights', 'Compose'], [p1, p2]):
    for label2, dtype in zip(['int', 'float'], [np.int64, np.float32]):
        for labeln, norm in zip(['no norm', '0...1'], [1, 255]):  
            print(f'preprocessor: {label}, dtype: {label2}, norm: {labeln}:')
            print(f'\t{torch.min(p(torch.from_numpy(n.astype(dtype) / norm)))}')
            print(f'\t{torch.max(p(torch.from_numpy(n.astype(dtype) / norm)))}')

[Screenshot of the script's output omitted]

Versions

torchvision: 0.21.0
torch: 2.6.0
numpy: 2.2.2


abhi-glitchhg commented Mar 18, 2025

I will look into the reproducible code that you have shared (maybe tomorrow) but I'm sure we do scale the inputs.
Refer to this line.

img = F.convert_image_dtype(img, torch.float)

F.convert_image_dtype handles the scaling.

https://pytorch.org/vision/main/generated/torchvision.transforms.functional.convert_image_dtype.html

I will go through your shared code tomorrow; it's pretty late right now.


Stochastic13 commented Mar 18, 2025

@abhi-glitchhg Thanks for looking at this!

In the code you cite, I do not understand why we are scaling the values to 0...1 based on dtype.
If the input is already floating point, I do not see any normalization:

https://github.com/pytorch/vision/blob/8ea4772e97bc11b2cfee48a415e7df8cd17fa682/torchvision/transforms/_functional_tensor.py#L70C1-L91C32

If the input is integer, I see a normalization based on the maximum value of the dtype.

Is this the intended behavior? The docs say

Finally the values are first rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].

I took this to mean that if I input my image with values 0 - 255 in any dtype, they would be rescaled to [0, 1]. If the input is integer, we assume the max is 255 irrespective of the dtype and raise an error if there are values exceeding 255. If float, and there are values exceeding 1, consider it again to be between 0 - 255. This would make sense and be what would be expected from an image preprocessor.

Apologies if I misunderstood the doc. If that is the case, I recommend clarifying this behavior in the documentation.


abhi-glitchhg commented Mar 19, 2025

If the input is integer, I see a normalization based on the maximum value of the dtype.

Correct.

If the input is already floating point, I do not see any normalization

Correct.

Generally, images are of type uint8, or sometimes uint16 if you work in the astronomy or medical domain. Hence the maximum possible values for those dtypes are 255 and 65535, respectively.

I took this to mean that if I input my image with values 0 - 255 in any dtype, they would be rescaled to [0, 1]

No, it depends on the dtype of the input. Only if the input image is of dtype uint8 (with data in the range 0-255) would it be rescaled to [0, 1]. Here's one example: I have two tensors which have data between 0-255, but one is of dtype uint16 and the other uint8.

import torch
import torchvision.transforms.functional as F

# same 0-255 data, two different integer dtypes
uint16_input = torch.randint(0, 255, (1, 3, 224, 224)).to(dtype=torch.uint16)
uint8_input = uint16_input.to(dtype=torch.uint8)

uint16_output = F.convert_image_dtype(uint16_input)  # divided by 65535
uint8_output = F.convert_image_dtype(uint8_input)    # divided by 255

print(uint16_output)
print("#")
print(uint8_output)

But if you are providing a float dtype with 0-255 values, then no rescaling is done.

If the input is integer, we assume the max is 255 irrespective of the dtype and raise an error if there are values exceeding 255. This would make sense and be what would be expected from an image preprocessor.

As I mentioned above, there are many integer types (int8, uint8, uint16, ...), and each one has a different maximum value.

If float, and there are values exceeding 1, consider it again to be between 0 - 255.

This is dangerous; as I mentioned, there are different integer datatypes, and assuming the data to be between 0-255 (or any other range) is not safe.
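
A quick sketch (not from the thread) of the per-dtype maxima that convert_image_dtype divides by; uint16, available in recent torch versions, similarly has a maximum of 65535:

```python
import torch

# convert_image_dtype divides by the *input dtype's* maximum,
# not a fixed 255 -- each integer dtype has its own:
maxima = {str(dt): torch.iinfo(dt).max
          for dt in (torch.uint8, torch.int8, torch.int16, torch.int32)}
print(maxima)
# uint8 -> 255, int8 -> 127, int16 -> 32767, int32 -> 2147483647
```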

===============================================

Coming back to the example code you have shared,

print(f'\t{torch.min(p(torch.from_numpy(n.astype(dtype) / norm)))}')

I think the problem is here: you are dividing the input by norm, which converts the numpy array to a float dtype irrespective of your typecasting (astype) operation.

Maybe that's why there was no scaling done by convert_image_dtype and p1 and p2 gave identical results.
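
This dtype promotion is easy to verify in isolation; the following sketch mirrors the division from the reproducible example:

```python
import numpy as np
import torch

n = np.random.randint(0, 255, (1, 3, 4, 4))

# True division always promotes to float64, even right after astype(np.int64),
# so convert_image_dtype sees a float tensor and skips the rescaling.
arr = n.astype(np.int64) / 1
print(arr.dtype)                    # float64
print(torch.from_numpy(arr).dtype)  # torch.float64
```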

I hope I have addressed the issue; if anything is unclear, feel free to comment.


abhi-glitchhg commented Mar 19, 2025

I'm assuming your input is in the range 0-255; make sure that the dtype of the input is uint8.

We do have this documented as well.
Check out this section - https://pytorch.org/vision/main/transforms.html#dtype-and-expected-value-range

@Stochastic13

@abhi-glitchhg Thanks for the clarification!
