-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathsampler.py
119 lines (95 loc) · 3.47 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
class SingleImageSampler:
    """Per-item sampler whose consecutive index groups each come from one image.

    Each of ``i_validation`` groups does two things:
    - picks one image uniformly at random (with replacement);
    - draws ``batch_size`` pixel positions inside that image (also with
      replacement).

    The flat index yielded is ``image_idx * N_pixels + pixel_idx``. This
    presumes rays are stored image by image, each image contributing exactly
    ``N_pixels`` rays. That is inferred from the index arithmetic; confirm
    against the dataset that consumes these indices.

    Indices are yielded one at a time, not as arrays.

    Args:
        batch_size: number of pixel indices drawn per group.
        N_img: number of images to choose from.
        N_pixels: number of pixels (rays) per image.
        i_validation: number of groups produced per pass over the sampler.
        tpu_num: divisor applied by ``__len__`` (see the note there).
    """

    def __init__(self, batch_size, N_img, N_pixels, i_validation, tpu_num):
        self.batch_size = batch_size
        self.N_pixels = N_pixels
        self.N_img = N_img
        # Not read by this class itself; presumably inspected by the
        # data-loading framework. TODO confirm.
        self.drop_last = False
        self.i_validation = i_validation
        self.tpu_num = tpu_num

    def __iter__(self):
        """Yield ``i_validation * batch_size`` flat ray indices."""
        image_choice = np.random.choice(
            np.arange(self.N_img), self.i_validation, replace=True
        )
        # Build the candidate pixel range once instead of once per group.
        # N_pixels can be large and the range never changes. Sampling from the
        # same array consumes the RNG exactly as before, so output is identical.
        pixel_range = np.arange(self.N_pixels)
        idx_choice = [
            np.random.choice(pixel_range, self.batch_size)
            for _ in range(self.i_validation)
        ]
        for image_idx, idx in zip(image_choice, idx_choice):
            # Offset this image's pixel indices into the flat ray index space.
            idx_ret = image_idx * self.N_pixels + idx
            yield from idx_ret

    def __len__(self):
        """Return ``i_validation * batch_size // tpu_num``."""
        # NOTE(review): __iter__ yields i_validation * batch_size indices, so
        # len() only matches the iteration count when tpu_num == 1. This is
        # presumably intentional per-TPU-core accounting; confirm before
        # relying on len().
        return self.i_validation * self.batch_size // self.tpu_num
class SingleImageDDPSampler(DistributedSampler):
    """Distributed batch sampler in which every batch comes from one image.

    Each of ``i_validation`` batches does two things:
    - picks one image uniformly at random (with replacement);
    - draws ``batch_size`` pixel positions inside it (with replacement).

    The batch becomes flat indices ``image_idx * N_pixels + pixel_idx``.
    Whole index arrays are yielded, and the current process keeps only its
    rank-strided slice ``batch[rank::world_size]``.

    NOTE(review): ``DistributedSampler.__init__`` is never called, so the
    base-class attributes (``dataset``, ``num_replicas``, ``rank``, ...) are
    not set. The subclassing presumably exists only to satisfy
    ``isinstance(..., DistributedSampler)`` checks; confirm.

    NOTE(review): every process draws its own random choices in ``__iter__``.
    The per-rank slices partition one shared batch only if the NumPy RNG is
    seeded identically on all ranks.

    Args:
        batch_size: number of pixel indices drawn per batch (before slicing).
        N_img: number of images to choose from.
        N_pixels: number of pixels (rays) per image.
        i_validation: number of batches produced per pass over the sampler.
    """

    def __init__(self, batch_size, N_img, N_pixels, i_validation):
        self.batch_size = batch_size
        self.N_pixels = N_pixels
        self.N_img = N_img
        # Not read by this class itself; presumably inspected by the
        # data-loading framework. TODO confirm.
        self.drop_last = False
        self.i_validation = i_validation

    def __iter__(self):
        """Yield ``i_validation`` index arrays, each this rank's share of a batch."""
        image_choice = np.random.choice(
            np.arange(self.N_img), self.i_validation, replace=True
        )
        # Build the candidate pixel range once instead of once per batch.
        # Sampling from the same array keeps the RNG stream identical.
        pixel_range = np.arange(self.N_pixels)
        idx_choice = [
            np.random.choice(pixel_range, self.batch_size)
            for _ in range(self.i_validation)
        ]
        # Requires an initialized default process group.
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        for image_idx, idx in zip(image_choice, idx_choice):
            idx_ret = image_idx * self.N_pixels + idx
            yield idx_ret[rank::num_replicas]

    def __len__(self):
        """Return the number of batches yielded by ``__iter__``."""
        return self.i_validation
class MultipleImageSampler:
    """Per-item sampler that draws ray indices from the entire ray set.

    ``__iter__`` prepares ``i_validation`` groups of ``batch_size`` indices.
    Each group is drawn with replacement from ``range(total_len)``, and the
    indices are yielded one at a time in group order.

    Args:
        batch_size: number of indices per group.
        total_len: size of the index space to sample from.
        i_validation: number of groups produced per pass over the sampler.
        tpu_num: stored on the instance.
    """

    def __init__(self, batch_size, total_len, i_validation, tpu_num):
        self.batch_size = batch_size
        self.total_len = total_len
        self.i_validation = i_validation
        # NOTE(review): kept as an attribute but never used by this class
        # (SingleImageSampler.__len__ does divide by its tpu_num). Confirm
        # whether __len__ here was meant to do the same.
        self.tpu_num = tpu_num

    def __iter__(self):
        """Yield ``i_validation * batch_size`` indices in ``[0, total_len)``."""
        population = np.arange(self.total_len)
        # Draw every group up front, then stream them out.
        groups = []
        for _ in range(self.i_validation):
            groups.append(np.random.choice(population, self.batch_size))
        for group in groups:
            yield from group

    def __len__(self):
        """Return the number of indices ``__iter__`` yields."""
        return self.i_validation * self.batch_size
class MultipleImageDDPSampler:
    """Distributed batch sampler over the entire ray set.

    Each of ``i_validation`` batches holds ``batch_size`` indices drawn with
    replacement from ``range(total_len)``. The current process yields only
    its rank-strided slice ``batch[rank::world_size]`` of each batch.

    NOTE(review): each process draws its own random batches. The slices
    partition one shared batch only if the NumPy RNG is seeded identically on
    every rank.

    Args:
        batch_size: number of indices per batch (before slicing).
        total_len: size of the index space to sample from.
        i_validation: number of batches produced per pass over the sampler.
    """

    def __init__(self, batch_size, total_len, i_validation):
        self.batch_size = batch_size
        self.total_len = total_len
        self.i_validation = i_validation

    def __iter__(self):
        """Yield ``i_validation`` index arrays, each this rank's share."""
        population = np.arange(self.total_len)
        drawn = [
            np.random.choice(population, self.batch_size)
            for _ in range(self.i_validation)
        ]
        # Requires an initialized default process group.
        my_rank = dist.get_rank()
        stride = dist.get_world_size()
        yield from (batch[my_rank::stride] for batch in drawn)

    def __len__(self):
        """Return the number of batches yielded by ``__iter__``."""
        return self.i_validation
class RaySet(Dataset):
    """Map-style dataset pairing each ray with its target value.

    ``images`` and ``rays`` are indexed in parallel with the same index.
    Both are converted with ``torch.from_numpy``, so they are expected to be
    NumPy arrays. Shapes are not checked here; TODO confirm against callers.

    Args:
        images: per-ray targets, returned under the ``"target"`` key.
        rays: per-ray inputs, returned under the ``"ray"`` key.
    """

    def __init__(self, images, rays):
        self.images = images
        self.rays = rays
        # Number of entries in ``images`` at construction time.
        self.N = len(images)

    def __getitem__(self, index):
        """Return ``{"target": ..., "ray": ...}`` tensors for ``index``.

        ``torch.from_numpy`` does not copy. When ``index`` selects a view
        (e.g. a scalar index into a 2-D array), the returned tensor shares
        memory with the stored array.
        """
        target = torch.from_numpy(self.images[index])
        ray = torch.from_numpy(self.rays[index])
        return {"target": target, "ray": ray}

    def __len__(self):
        """Return the current number of entries in ``images``."""
        return len(self.images)