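"""Smoke test for MultiScaleSamplerDDP with paddle.distributed.

Spawns one worker per visible GPU; each worker builds a DataLoader driven by
the multi-scale batch sampler and pulls a single batch to confirm that the
sampler's (width, height, index) tuples reach the dataset.

Run directly (assumes at least one CUDA device is visible):
    python test_multi_scale_sampler.py
"""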
import paddle
import paddle.distributed as dist
from paddle.io import Dataset, DataLoader

from multi_scale_sampler import MultiScaleSamplerDDP


class DummyDataset(Dataset):
    """Synthetic dataset that honors the (width, height, index) keys
    produced by MultiScaleSamplerDDP."""

    def __init__(self):
        super().__init__()

    def __getitem__(self, index):
        # The multi-scale sampler yields (width, height, sample_idx) tuples
        # instead of plain integer indices.
        w, h, idx = index
        data = paddle.randn([3, int(w), int(h)])
        label = paddle.randn([1])
        return data, label

    def __len__(self):
        return 5000


def get_dataset():
    return DummyDataset()


# A custom collate_fn could be passed to the DataLoader below to inspect
# per-sample shapes before batching; the default collation suffices here.
def get_dataloader(dataset):
    # Positional args follow the sampler's signature in multi_scale_sampler.py
    # (base width, base height, batch size, dataset length).
    sampler = MultiScaleSamplerDDP(224, 224, 4, 5000, is_train=True)
    dataloader = DataLoader(dataset,
                            batch_sampler=sampler,
                            num_workers=1)
    return dataloader


def main_worker(*args):
    dataset = args[0]
    dist.init_parallel_env()
    dataloader = get_dataloader(dataset)
    local_rank = dist.get_rank()
    for batch_id, data in enumerate(dataloader):
        # Fetch a single batch as a smoke test; extend the loop (or print
        # data[0].shape per rank) to inspect the multi-scale behavior.
        break


def main():
    dataset_val = get_dataset()
    # Launch one worker process per visible CUDA device.
    ngpus = len(paddle.static.cuda_places())
    dist.spawn(main_worker, args=(dataset_val,), nprocs=ngpus)


if __name__ == "__main__":
    main()