Skip to content

Commit ec43437

Browse files
committed
torch cond
1 parent 8ea4772 commit ec43437

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

torchvision/transforms/_functional_tensor.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -919,8 +919,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
919919
dtype = tensor.dtype
920920
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
921921
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
922-
if (std == 0).any():
923-
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
922+
923+
def stdzero():
924+
raise ValueError(
925+
f"std evaluated to zero after conversion to {dtype}, leading to division by zero."
926+
)
927+
928+
torch.cond((std == 0).any(), stdzero, lambda: None)
924929
if mean.ndim == 1:
925930
mean = mean.view(-1, 1, 1)
926931
if std.ndim == 1:

0 commit comments

Comments
 (0)