Rework the message passing in pool to have better error handling.
diff --git a/typ/pool.py b/typ/pool.py
index ea0d0d2..79f1cf8 100644
--- a/typ/pool.py
+++ b/typ/pool.py
@@ -13,16 +13,19 @@
# limitations under the License.
import copy
+import enum
import multiprocessing
from typ.host import Host
-try:
- # This gets compatibility for both Python 2 and Python 3.
- # import failure ... pylint: disable=F0401
- from queue import Empty
-except ImportError:
- from Queue import Empty
+
+class _MessageType(enum.Enum):
+ # Class has no __init__ pylint: disable=W0232
+ Request = 1
+ Response = 2
+ Close = 3
+ Done = 4
+ Error = 5
def make_pool(host, jobs, callback, context, pre_fn, post_fn):
@@ -38,6 +41,7 @@
self.responses = multiprocessing.Queue()
self.workers = []
self.closed = False
+ self.erred = False
for worker_num in range(jobs):
w = multiprocessing.Process(target=_loop,
args=(self.requests, self.responses,
@@ -48,19 +52,23 @@
self.workers.append(w)
def send(self, msg):
- self.requests.put((True, msg))
+ self.requests.put((_MessageType.Request, msg))
def get(self, block=True, timeout=None):
- return self.responses.get(block, timeout)
+ msg_type, resp = self.responses.get(block, timeout)
+ if msg_type == _MessageType.Error: # pragma: no cover
+ self._handle_error(resp)
+ assert msg_type == _MessageType.Response
+ return resp
def close(self):
for _ in self.workers:
- self.requests.put((False, None))
+ self.requests.put((_MessageType.Close, None))
self.requests.close()
self.closed = True
def join(self):
- final_contexts = []
+ final_responses = []
if not self.closed:
self.requests.close()
for w in self.workers:
@@ -68,10 +76,22 @@
w.join()
else:
for w in self.workers:
- final_contexts.append(self.responses.get(True))
+ while True:
+ msg_type, resp = self.responses.get(True)
+ if msg_type == _MessageType.Error: # pragma: no cover
+ self._handle_error(resp)
+ elif msg_type == _MessageType.Done:
+ break
+ # TODO: log something about discarding messages?
+ final_responses.append(resp)
w.join()
self.responses.close()
- return final_contexts
+ return final_responses
+
+ def _handle_error(self, msg): # pragma: no cover
+ worker_num, ex_str = msg
+ self.erred = True
+ raise Exception("error from worker %d: %s" % (worker_num, ex_str))
class AsyncPool(object):
@@ -107,20 +127,22 @@
callback, context, pre_fn, post_fn): # pragma: no cover
# TODO: Figure out how to get coverage to work w/ subprocesses.
host = host or Host()
+ erred = False
try:
context_after_pre = pre_fn(host, worker_num, context)
- keep_going = True
- while keep_going:
- keep_going, args = requests.get(block=True)
- if keep_going:
- resp = callback(context_after_pre, args)
- responses.put(resp)
- except Empty:
+ while True:
+ message_type, args = requests.get(block=True)
+ if message_type == _MessageType.Close:
+ break
+ assert message_type == _MessageType.Request
+ resp = callback(context_after_pre, args)
+ responses.put((_MessageType.Response, resp))
+ except Exception as e:
+ erred = True
+ responses.put((_MessageType.Error, (worker_num, str(e))))
+
+ try:
+ if not erred:
+ responses.put((_MessageType.Done, post_fn(context_after_pre)))
+ except Exception:
pass
- except IOError:
- pass
- except KeyboardInterrupt:
- pass
- finally:
- # TODO: Figure out how to propagate errors back.
- responses.put(post_fn(context_after_pre))