Skip to content
This repository was archived by the owner on Jun 5, 2024. It is now read-only.

Commit 4eb241a

Browse files
committed
Added Multi-threading
1 parent ad0d716 commit 4eb241a

File tree

2 files changed

+599
-0
lines changed

2 files changed

+599
-0
lines changed

enqueuer.py

+282
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# coding=utf-8
2+
"""Given the dataset object, make a multiprocess/thread enqueuer"""
3+
import os
4+
import queue
5+
import threading
6+
import contextlib
7+
import multiprocessing
8+
import time
9+
import random
10+
import sys
11+
import utils
12+
import traceback
13+
import numpy as np
14+
15+
# TODO: check out https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader
16+
# ------------------------------- the following is only needed for multiprocess
17+
# multiprocess is only good for video inputs (num_workers=num_core)
18+
# multithreading is good enough for frame inputs
19+
# and, empirically, num_workers=4 is optimal with threads on many kinds of machines
20+
21+
# Global variables to be shared across processes
22+
_SHARED_DATASETS = {}
23+
# We use a Value to provide unique id to different processes.
24+
_SEQUENCE_COUNTER = None
25+
# Because multiprocessing pools are inherently unsafe, starting from a clean
26+
# state can be essential to avoiding deadlocks. In order to accomplish this, we
27+
# need to be able to check on the status of Pools that we create.
28+
_WORKER_ID_QUEUE = None # Only created if needed.
29+
30+
# modified from keras
31+
class DatasetEnqueuer(object):
    """Background prefetcher that turns a dataset into a stream of batches.

    A daemon thread submits one `get_index(uid, idx)` task per sample index
    to a thread pool or process pool and stores the pending results in a
    bounded queue. `get()` pops those results in submission order and
    assembles them into batches with `dataset.collect_batch`.

    The dataset object must provide `batch_size`, `valid_idxs`,
    `num_epochs`, `num_batches` and `collect_batch(...)`, plus
    `batch_size_per_gpu` when `is_multi_gpu` is set.
    """

    def __init__(self, dataset, prefetch=5, num_workers=1,
                 start=True,  # start the dataset get thread when init
                 shuffle=False,
                 # whether to break down each mini-batch for each gpu
                 is_multi_gpu=False,
                 last_full_batch=False,  # make sure the last batch is full
                 use_process=False,  # use process instead of thread
                 ):
        """Set up the enqueuer and optionally start prefetching right away.

        Args:
            dataset: the dataset object to read samples from.
            prefetch: number of batches worth of samples to keep queued.
            num_workers: size of the worker pool.
            start: when True, call `start()` at the end of construction.
            shuffle: when True, sample indices are visited in random order.
            is_multi_gpu: when True, `get()` yields a list of per-GPU
                batches instead of a single batch.
            last_full_batch: when True, the index list is padded with its
                last element up to a multiple of `dataset.batch_size`.
            use_process: when True, use a process pool instead of threads.
        """
        # Lifecycle state first, so is_running()/__del__ stay safe even if
        # something below raises.
        self.queue = None
        self.run_thread = None  # the thread to spawn others
        self.stop_signal = None

        self.dataset = dataset

        self.prefetch = prefetch  # how many batch to save in queue
        self.max_queue_size = int(self.prefetch * dataset.batch_size)

        self.workers = num_workers

        self.cur_batch_count = 0

        self.shuffle = shuffle

        self.use_process = use_process

        self.is_multi_gpu = is_multi_gpu
        self.last_full_batch = last_full_batch

        # need to have a global uid for each enqueuer so we could use
        # train/val at the same time
        global _SEQUENCE_COUNTER
        if _SEQUENCE_COUNTER is None:
            try:
                _SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
            except OSError:
                # In this case the OS does not allow us to use
                # multiprocessing. We resort to an int
                # for enqueuer indexing.
                _SEQUENCE_COUNTER = 0

        if isinstance(_SEQUENCE_COUNTER, int):
            self.uid = _SEQUENCE_COUNTER
            _SEQUENCE_COUNTER += 1
        else:
            # Doing Multiprocessing.Value += x is not process-safe.
            with _SEQUENCE_COUNTER.get_lock():
                self.uid = _SEQUENCE_COUNTER.value
                _SEQUENCE_COUNTER.value += 1

        if start:
            self.start()

    def is_running(self):
        """Return True while the enqueuer is started and not yet stopped."""
        stop_signal = getattr(self, "stop_signal", None)
        return stop_signal is not None and not stop_signal.is_set()

    def start(self):
        """Create the result queue and launch the background producer."""
        if self.use_process:
            self.executor_fn = self._get_executor_init(self.workers)
        else:
            # `import multiprocessing` alone does not load the `pool`
            # submodule, so `multiprocessing.pool.ThreadPool` can raise
            # AttributeError. Import the class explicitly instead.
            from multiprocessing.pool import ThreadPool
            self.executor_fn = lambda _: ThreadPool(self.workers)

        self.queue = queue.Queue(self.max_queue_size)
        self.stop_signal = threading.Event()

        self.run_thread = threading.Thread(target=self._run)
        self.run_thread.daemon = True
        self.run_thread.start()

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Arguments:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """
        def pool_fn(seqs):
            pool = multiprocessing.Pool(
                workers, initializer=init_pool_generator,
                initargs=(seqs, None, get_worker_id_queue()))
            return pool

        return pool_fn

    def stop(self):
        """Stop prefetching if the enqueuer is currently running."""
        if self.is_running():
            self._stop()

    def _stop(self):
        """Signal the producer to exit and drop everything still queued."""
        self.stop_signal.set()
        with self.queue.mutex:
            self.queue.queue.clear()
            self.queue.unfinished_tasks = 0
            # Wake the producer in case it is blocked on a full queue so it
            # can observe the stop signal.
            self.queue.not_full.notify()

        # Non-blocking join: the producer is a daemon thread, do not wait.
        self.run_thread.join(0)

        _SHARED_DATASETS[self.uid] = None

    def __del__(self):
        if self.is_running():
            self._stop()

    def _send_dataset(self):
        """Sends current Iterable to all workers."""
        # For new processes that may spawn
        _SHARED_DATASETS[self.uid] = self.dataset

    def _build_batch_idxs(self):
        """Return the ordered list of sample indices the producer submits.

        The list is `dataset.valid_idxs` repeated `dataset.num_epochs`
        times, optionally shuffled, and optionally padded with its last
        element so that its length is a multiple of `dataset.batch_size`.
        """
        batch_idxs = list(self.dataset.valid_idxs) * self.dataset.num_epochs

        if self.shuffle:
            batch_idxs = random.sample(batch_idxs, len(batch_idxs))

        if self.last_full_batch and batch_idxs:
            # Number of samples missing from the final batch; 0 when the
            # list is already a multiple of batch_size (no extra batch).
            missing = -len(batch_idxs) % self.dataset.batch_size
            batch_idxs += [batch_idxs[-1]] * missing

        return batch_idxs

    # preprocess the data and put them into queue
    def _run(self):
        """Producer loop: submit sample tasks and queue their results.

        Runs in the background thread created by `start()` and returns as
        soon as the stop signal is set.
        """
        batch_idxs = self._build_batch_idxs()

        self._send_dataset()  # Share the initial dataset

        while True:
            with contextlib.closing(
                    self.executor_fn(_SHARED_DATASETS)) as executor:
                for idx in batch_idxs:
                    if self.stop_signal.is_set():
                        return
                    # block until not full
                    self.queue.put(
                        executor.apply_async(get_index, (self.uid, idx)),
                        block=True)

                self._wait_queue()
                if self.stop_signal.is_set():
                    # We're done
                    return

            self._send_dataset()  # Update the pool

    # get batch from the queue
    # TODO: this is single thread, put the batch collecting into multi-thread
    def get(self):
        """Yield batches until `dataset.num_batches` have been produced.

        Starts the enqueuer first if it is not running. Each batch is built
        from `dataset.batch_size` consecutive queued samples. With
        `is_multi_gpu` the yielded value is a list of batches, one per
        index group returned by `utils.grouper`; otherwise it is the result
        of `dataset.collect_batch(samples)`.

        Raises:
            Exception: any error raised while fetching or collecting is
                logged, the enqueuer is stopped, and the original exception
                is re-raised.
        """
        if not self.is_running():
            self.start()
        try:
            while self.is_running():
                if self.cur_batch_count == self.dataset.num_batches:
                    self._stop()
                    return

                samples = []
                for _ in range(self.dataset.batch_size):
                    # first get got the ApplyResult object,
                    # then second get to get the actual thing
                    # (block till get)
                    sample = self.queue.get(block=True).get()
                    self.queue.task_done()
                    samples.append(sample)

                # break the mini-batch into mini-batches for multi-gpu
                if self.is_multi_gpu:
                    # pack these batches for each gpu
                    this_batch_idxs_gpus = utils.grouper(
                        range(len(samples)), self.dataset.batch_size_per_gpu)
                    batch = [
                        self.dataset.collect_batch(samples, idxs_per_gpu)
                        for idxs_per_gpu in this_batch_idxs_gpus
                    ]
                else:
                    batch = self.dataset.collect_batch(samples)

                self.cur_batch_count += 1
                yield batch

        except Exception as e:  # pylint: disable=broad-except
            self._stop()
            print("Exception in enqueuer.get: %s" % e)
            traceback.print_tb(e.__traceback__)
            # Re-raise the original error so callers keep its type/message.
            raise

    def _wait_queue(self):
        """Wait for the queue to be empty."""
        while True:
            time.sleep(0.1)
            if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set():
                return
233+
234+
235+
def get_worker_id_queue():
    """Return the module-wide worker-id queue, building it on first call."""
    global _WORKER_ID_QUEUE
    worker_ids = _WORKER_ID_QUEUE
    if worker_ids is None:
        # First request: create the queue and remember it for later calls.
        worker_ids = multiprocessing.Queue()
        _WORKER_ID_QUEUE = worker_ids
    return worker_ids
241+
242+
def get_index(uid, i):
    """Fetch sample `i` from the dataset registered under `uid`.

    Datasets are looked up by `uid` in `_SHARED_DATASETS` so that several
    enqueuers can be active at once without one of them (e.g. validation)
    overwriting the dataset of another (e.g. training).

    Arguments:
        uid: int, identifier of the enqueuer that owns the dataset
        i: index of the sample to fetch

    Returns:
        Whatever the dataset's `get_sample(i)` returns.
    """
    dataset = _SHARED_DATASETS[uid]
    return dataset.get_sample(i)
257+
258+
def init_pool_generator(gens, random_seed=None, id_queue=None):
    """Initializer run once in every pool worker.

    Args:
        gens: State to install as the worker's `_SHARED_DATASETS`.
        random_seed: Optional base seed; when given, numpy's global RNG is
            seeded with `random_seed` plus the worker's process ident.
        id_queue: Optional queue; when given, the worker's process ident
            is put on it so the creator can track the workers it spawned.
    """
    global _SHARED_DATASETS
    _SHARED_DATASETS = gens

    this_proc = multiprocessing.current_process()
    ident = this_proc.ident

    # The name is not used for anything, but a descriptive one helps when
    # diagnosing orphaned processes.
    this_proc.name = 'Enqueuer_worker_{}'.format(this_proc.name)

    if random_seed is not None:
        np.random.seed(random_seed + ident)

    if id_queue is not None:
        # If a worker dies during init, the pool will just create a
        # replacement.
        id_queue.put(ident, block=True, timeout=0.1)

0 commit comments

Comments
 (0)