Skip to content

Commit d330d75

Browse files
authored
Fix RAFT input dimension check (#8851)
1 parent 36febf5 commit d330d75

File tree

1 file changed

+1
-1
lines changed
  • torchvision/models/optical_flow

1 file changed

+1
-1
lines changed

torchvision/models/optical_flow/raft.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def forward(self, image1, image2, num_flow_updates: int = 12):
486486
batch_size, _, h, w = image1.shape
487487
if (h, w) != image2.shape[-2:]:
488488
raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
489-
if not (h % 8 == 0) and (w % 8 == 0):
489+
if not ((h % 8 == 0) and (w % 8 == 0)):
490490
raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")
491491

492492
fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))

0 commit comments

Comments
 (0)