@@ -662,27 +662,39 @@ class VideoDatasetTestCase(DatasetTestCase):
662
662
FEATURE_TYPES = (torch .Tensor , torch .Tensor , int )
663
663
REQUIRED_PACKAGES = ("av" ,)
664
664
665
- DEFAULT_FRAMES_PER_CLIP = 1
665
+ FRAMES_PER_CLIP = 1
666
666
667
667
def __init__ (self , * args , ** kwargs ):
668
668
super ().__init__ (* args , ** kwargs )
669
669
self .dataset_args = self ._set_default_frames_per_clip (self .dataset_args )
670
670
671
- def _set_default_frames_per_clip (self , inject_fake_data ):
671
+ def _set_default_frames_per_clip (self , dataset_args ):
672
672
argspec = inspect .getfullargspec (self .DATASET_CLASS .__init__ )
673
673
args_without_default = argspec .args [1 : (- len (argspec .defaults ) if argspec .defaults else None )]
674
674
frames_per_clip_last = args_without_default [- 1 ] == "frames_per_clip"
675
675
676
- @functools .wraps (inject_fake_data )
676
+ @functools .wraps (dataset_args )
677
677
def wrapper (tmpdir , config ):
678
- args = inject_fake_data (tmpdir , config )
678
+ args = dataset_args (tmpdir , config )
679
679
if frames_per_clip_last and len (args ) == len (args_without_default ) - 1 :
680
- args = (* args , self .DEFAULT_FRAMES_PER_CLIP )
680
+ args = (* args , self .FRAMES_PER_CLIP )
681
681
682
682
return args
683
683
684
684
return wrapper
685
685
686
+ def test_output_format (self ):
687
+ for output_format in ["TCHW" , "THWC" ]:
688
+ with self .create_dataset (output_format = output_format ) as (dataset , _ ):
689
+ for video , * _ in dataset :
690
+ if output_format == "TCHW" :
691
+ num_frames , num_channels , * _ = video .shape
692
+ else : # output_format == "THWC":
693
+ num_frames , * _ , num_channels = video .shape
694
+
695
+ assert num_frames == self .FRAMES_PER_CLIP
696
+ assert num_channels == 3
697
+
686
698
@test_all_configs
687
699
def test_transforms_v2_wrapper (self , config ):
688
700
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
0 commit comments