Skip to content

Commit 0b05122

Browse files
committed
Fixing hieradet (sam2) tests
1 parent e035381 commit 0b05122

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

tests/test_models.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -52,14 +52,14 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
```

timm/models/hieradet_sam2.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -45,7 +45,7 @@ def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw:
     """
     H, W = hw
     B = windows.shape[0] // (H * W // window_size[0] // window_size[1])
-    x = windows.view(B, H // window_size[0], W // window_size[0], window_size[0], window_size[1], -1)
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
     return x

@@ -567,11 +567,12 @@ def checkpoint_filter_fn(state_dict, model=None, prefix=''):

 def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
     out_indices = kwargs.pop('out_indices', 4)
-    if True:  # kwargs.get('pretrained_cfg', '') == '?':
+    checkpoint_prefix = ''
+    if 'sam2' in variant:
         # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
         # This is workaround loading with num_classes=0 w/o removing norm-layer.
         kwargs.setdefault('pretrained_strict', False)
-        checkpoint_prefix = 'image_encoder.trunk.' if 'sam2' in variant else ''
+        checkpoint_prefix = 'image_encoder.trunk.'
     return build_model_with_cfg(
         HieraDet,
         variant,
```

Comments (0)