simplify pool, refactor printing
diff --git a/pytest.py b/pytest.py
index a9e4f1e..4f3e867 100644
--- a/pytest.py
+++ b/pytest.py
@@ -82,48 +82,30 @@
def run_tests(args, printer, stats, test_names):
returncode = 0
running_jobs = set()
- pool = make_pool(args.jobs, run_test, args)
- pool_closed = False
stats.total = len(test_names)
+
+ pool = make_pool(args.jobs, run_test, args)
try:
while test_names or running_jobs:
while test_names and len(running_jobs) < args.jobs:
test_name = test_names.pop(0)
stats.started += 1
- if not args.quiet and printer.should_overwrite:
- printer.update(stats.format() + test_name,
- elide=(not args.verbose))
-
pool.send(test_name)
running_jobs.add(test_name)
+ print_test_started(printer, args, stats, test_name)
- if not test_names and not pool_closed:
- pool.close()
- pool_closed = True
-
- test_name, res, out, err = pool.get(block=True)
+ test_name, res, out, err = pool.get()
running_jobs.remove(test_name)
-
- stats.finished += 1
if res:
returncode = 1
- suffix = ' failed' + (':\n' if (out or err) else '')
- printer.update(stats.format() + test_name + suffix,
- elide=False)
- elif not args.quiet or out or err:
- suffix = ' passed' + (':' if (out or err) else '')
- printer.update(stats.format() + test_name + suffix,
- elide=(not out and not err))
- for l in out.splitlines():
- print_(' %s' % l)
- for l in err.splitlines():
- print_(' %s' % l, stream=sys.stderr)
+ stats.finished += 1
+ print_test_finished(printer, args, stats, test_name,
+ res, out, err)
finally:
- pool.terminate()
- pool.join()
+ pool.close()
if not args.quiet or returncode:
- print_('')
+ print_()
return returncode
@@ -141,9 +123,28 @@
return test_name, 0, result.out, result.err
-def print_(msg, end='\n', stream=sys.stdout):
+def print_test_started(printer, args, stats, test_name):
+ if not args.quiet and printer.should_overwrite:
+ printer.update(stats.format() + test_name, elide=(not args.verbose))
+
+
+def print_test_finished(printer, args, stats, test_name, res, out, err):
+ if res:
+ suffix = ' failed' + (':\n' if (out or err) else '')
+ printer.update(stats.format() + test_name + suffix, elide=False)
+ elif not args.quiet or out or err:
+ suffix = ' passed' + (':' if (out or err) else '')
+ printer.update(stats.format() + test_name + suffix,
+ elide=(not out and not err))
+ for l in out.splitlines():
+ print_(' %s' % l)
+ for l in err.splitlines():
+ print_(' %s' % l, stream=sys.stderr)
+
+
+def print_(msg='', end='\n', stream=sys.stdout):
stream.write(str(msg) + end)
- stream.write.flush()
+ stream.flush()
class PassThrough(StringIO.StringIO):
diff --git a/pytest_pool.py b/pytest_pool.py
index f3d143f..47ee429 100644
--- a/pytest_pool.py
+++ b/pytest_pool.py
@@ -35,12 +35,10 @@
self.requests.put((False, None))
self.requests.close()
- def terminate(self):
for w in self.workers:
w.terminate()
self.responses.close()
- def join(self):
for w in self.workers:
w.join()
@@ -60,12 +58,6 @@
def close(self):
pass
- def terminate(self):
- pass
-
- def join(self):
- pass
-
def _loop(_worker_num, callback, usrp, requests, responses):
try: