-
Notifications
You must be signed in to change notification settings - Fork 506
/
Copy pathparallel_loader.py
209 lines (177 loc) · 6.54 KB
/
parallel_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import itertools
import threading
import torch
import torch_xla
import torch_xla.debug.profiler as xp
import torch_xla.utils.keyd_queue as kq
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
class PerDeviceQueue:
    """The pair of queues (plus bookkeeping) that feed one device.

    Args:
      device: The device these queues serve.
      loader_prefetch_size (int): Capacity of the host-side ``loader_queue``.
      device_prefetch_size (int): Capacity of the device-side ``queue``.
    """

    def __init__(self, device, loader_prefetch_size, device_prefetch_size):
        # Host-side staging queue: samples read from the wrapped loader wait
        # here until a transfer thread picks them up.
        staging = kq.Queue(maxsize=loader_prefetch_size)
        # Queue of data that transfer threads have already sent to `device`.
        ready = kq.Queue(maxsize=device_prefetch_size)

        self.device = device
        self.loader_queue = staging
        self.queue = ready
        # Each transfer thread draws one number from this counter when it
        # finishes, so the last one can tell it should close `queue`.
        self.close_queue_count = itertools.count()
class PerDeviceLoader:
    """Python iterator over the items a ParallelLoader prepared for one device.

    Between fetches it decides whether to call ``xm.mark_step()``. A counter of
    handed-out batches is reset (with a mark_step) once it reaches
    ``loader.batches_per_execution - 1``; otherwise it is incremented.

    Args:
      loader: The owning ParallelLoader. Must expose ``batches_per_execution``,
        ``per_device_samples()`` and ``next_item(device)``.
      device: The device whose queue this iterator drains.
    """

    def __init__(self, loader, device):
        self._loader = loader
        self._device = device
        # Counter threshold at which the next call issues a mark_step.
        self._mark_step_batch_count = loader.batches_per_execution - 1
        self._batches_yielded = 0

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._loader.per_device_samples()

    def next(self):
        """Return the next device item, or raise StopIteration at end of data."""
        if xp.get_tracer_marked_step():
            # The profiler flag says a step was already marked (presumably by
            # a trace). Clear it and count this batch without calling
            # mark_step again.
            xp.set_tracer_marked_step(False)
            self._batches_yielded += 1
        elif self._batches_yielded >= self._mark_step_batch_count:
            self._batches_yielded = 0
            xm.mark_step()
        else:
            self._batches_yielded += 1

        item = self._loader.next_item(self._device)
        if item is not None:
            return item
        # None is the end-of-data marker: flush pending work before stopping.
        xm.mark_step()
        raise StopIteration
class ParallelLoader:
    """Wraps an existing PyTorch DataLoader with background data upload.

    One background thread reads items from ``loader`` and deals them out
    round-robin to per-device host queues. For each device,
    ``host_to_device_transfer_threads`` threads send those items with
    ``xm.send_cpu_data_to_device`` and place the results in a per-device queue.
    The iterator returned by :meth:`per_device_loader` drains that queue.

    If a background thread raises, the exception is recorded. Once the
    affected device queue has been drained, it is re-raised to the consumer
    instead of waiting forever.

    Args:
      loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
        wrapped.
      devices (`torch.device`...): The list of devices where the data has to be
        sent. The i-th sample returned by the `loader` will be sent to `devices[i
        % len(devices)]`. A trailing group of fewer than `len(devices)` samples
        is dropped.
      batchdim (int, optional): The dimension which is holding the batch size.
        Default: 0
      batches_per_execution (int, optional): Controls how often the per-device
        iterator calls `xm.mark_step()`: roughly once every
        `batches_per_execution` batches.
        Default: 1
      loader_prefetch_size (int, optional): The max capacity of the queue used by
        the thread which is reading samples from the `loader`, to be processed by
        the worker threads which upload data to the devices.
        Default: 8
      device_prefetch_size (int, optional): The max size of the per-device queues,
        where the worker threads deposit tensors which have already been sent to
        devices.
        Default: 4
      host_to_device_transfer_threads (int, optional): The number of threads that
        work in parallel to transfer data from loader queue to device queue.
        Must be at least 1.
        Default: 1
      input_sharding (ShardingSpec, optional): Sharding spec to apply to
        compatible input tensors after loading.
        Default: None

    Raises:
      ValueError: If `host_to_device_transfer_threads` is less than 1. With no
        transfer threads the device queues would never be filled or closed.
    """

    def __init__(self,
                 loader,
                 devices,
                 batchdim=0,
                 batches_per_execution=1,
                 loader_prefetch_size=8,
                 device_prefetch_size=4,
                 host_to_device_transfer_threads=1,
                 input_sharding=None):
        # Validate before any thread is started.
        if host_to_device_transfer_threads < 1:
            raise ValueError(
                'host_to_device_transfer_threads must be >= 1, got '
                f'{host_to_device_transfer_threads}')
        self._loader = loader
        self._devices = [torch.device(x) for x in devices]
        self._batchdim = batchdim
        self._batches_per_execution = batches_per_execution
        self._done = False
        # First exception raised by a background thread. next_item() re-raises
        # it in place of the end-of-data marker.
        self._worker_error = None
        self._queues = dict()
        self._input_sharding = input_sharding
        for device in self._devices:
            self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
                                                  device_prefetch_size)
        thread = threading.Thread(target=self._loader_worker)
        thread.daemon = True
        thread.start()
        for dqueue in self._queues.values():
            for _ in range(host_to_device_transfer_threads):
                thread = threading.Thread(
                    target=self._worker,
                    args=(dqueue, host_to_device_transfer_threads),
                )
                thread.daemon = True
                thread.start()

    def per_device_loader(self, device):
        """Retrieves the loader iterator object for the given device.

        Args:
          device (`torch.device`): The device whose loader is being requested.

        Returns:
          The loader iterator object for the `device`. This is not a
          `torch.utils.data.DataLoader` interface, but a Python iterator which
          returns the same tensor data structure as returned by the wrapped
          `torch.utils.data.DataLoader`, but residing on XLA devices.
        """
        return PerDeviceLoader(self, torch.device(device))

    def per_device_samples(self):
        """Return how many loader items each device receives.

        The result is ``len(loader) // len(devices)``; any remainder is dropped.
        """
        return len(self._loader) // len(self._devices)

    def next_item(self, device):
        """Fetch the next item from `device`'s queue.

        Returns:
          The next device item, or None once the queue is closed and drained.

        Raises:
          Exception: The error recorded from a background loader/transfer
            thread. It is raised in place of the None end-of-data marker.
        """
        item = self._queues[device].queue.get()
        if item is None and self._worker_error is not None:
            raise self._worker_error
        return item

    def close(self):
        """Stop loading and close the host and device queues of every device."""
        self._done = True
        for dqueue in self._queues.values():
            dqueue.queue.close()
            dqueue.loader_queue.close()

    @property
    def batches_per_execution(self):
        return self._batches_per_execution

    def _record_worker_error(self, err):
        """Remember the first background failure, ignoring any after close()."""
        if not self._done and self._worker_error is None:
            self._worker_error = err

    def _loader_worker(self):
        """Loader-thread body: deal loader items round-robin to host queues."""
        queues = list(self._queues.values())
        try:
            data_iter = iter(self._loader)
            batch = []
            # _done is checked before every fetch so close() stops reading.
            while not self._done:
                try:
                    data = next(data_iter)
                except StopIteration:
                    break
                batch.append(data)
                if len(batch) == len(self._devices):
                    for queue_no, device_batch in enumerate(batch):
                        queues[queue_no].loader_queue.put(device_batch)
                    batch = []
        except Exception as err:
            self._record_worker_error(err)
            # Re-raise so the thread traceback is still reported as before.
            raise
        finally:
            # Always close the host queues, even on failure, so transfer threads
            # (and ultimately the consumer) see end-of-data instead of hanging.
            for dqueue in queues:
                dqueue.loader_queue.close_write()

    def _get_batch(self, dqueue):
        """Pull up to `dqueue.queue.max_size()` items from the host queue.

        Stops early when the host queue returns None (end of data).
        """
        batch = []
        while dqueue.queue.max_size() > len(batch):
            item = dqueue.loader_queue.get()
            if item is None:
                break
            batch.append(item)
        return batch

    def _worker(self, dqueue, host_to_device_transfer_threads):
        """Transfer-thread body: move host items to the device queue."""
        try:
            device = torch.device(dqueue.device)
            while True:
                batch = self._get_batch(dqueue)
                if not batch:
                    break
                batch = xm.send_cpu_data_to_device(batch, device,
                                                   self._input_sharding)
                for data in batch:
                    dqueue.queue.put(data)
        except Exception as err:
            self._record_worker_error(err)
            raise
        finally:
            # Every transfer thread counts itself out exactly once, even on
            # failure. The last one closes the device queue.
            close_queue_count = next(dqueue.close_queue_count)
            if close_queue_count == host_to_device_transfer_threads - 1:
                dqueue.queue.close_write()
class MpDeviceLoader:
    """Wraps an existing PyTorch DataLoader with background data upload.

    Intended only for multi-processing data parallelism, where each process
    drives a single device. Every call to ``iter()`` builds a fresh
    ``ParallelLoader`` over ``[device]``, with its own background threads, and
    returns that loader's per-device iterator.

    Args:
      loader (:class:`torch.utils.data.DataLoader`): The PyTorch DataLoader to be
        wrapped.
      device (`torch.device`): The device where the data has to be sent.
      kwargs: Named arguments forwarded to the `ParallelLoader` constructor.
    """

    def __init__(self, loader, device, **kwargs):
        self._loader, self._device = loader, device
        self._parallel_loader_kwargs = kwargs

    def __iter__(self):
        single_device = [self._device]
        wrapper = ParallelLoader(self._loader, single_device,
                                 **self._parallel_loader_kwargs)
        return wrapper.per_device_loader(self._device)

    def __len__(self):
        # With exactly one device, every wrapped-loader item goes to it.
        return len(self._loader)