Merge pull request #26 from bukzor/try-else
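Fix complexity counting for try/except/else blocks: the `else` clause is now counted as its own path, and a `visitTry` alias handles `Try` nodes (Python 3's AST name) the same way as `TryExcept`.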
diff --git a/mccabe.py b/mccabe.py
index 9738251..f7c1778 100644
--- a/mccabe.py
+++ b/mccabe.py
@@ -170,25 +170,29 @@
name = "If %d" % node.lineno
self._subgraph(node, name)
- def _subgraph(self, node, name):
+ def _subgraph(self, node, name, extra_blocks=()):
"""create the subgraphs representing any `if` and `for` statements"""
if self.graph is None:
# global loop
self.graph = PathGraph(name, name, node.lineno)
pathnode = PathNode(name)
- self._subgraph_parse(node, pathnode)
+ self._subgraph_parse(node, pathnode, extra_blocks)
self.graphs["%s%s" % (self.classname, name)] = self.graph
self.reset()
else:
pathnode = self.appendPathNode(name)
- self._subgraph_parse(node, pathnode)
+ self._subgraph_parse(node, pathnode, extra_blocks)
- def _subgraph_parse(self, node, pathnode):
+ def _subgraph_parse(self, node, pathnode, extra_blocks):
"""parse the body and any `else` block of `if` and `for` statements"""
loose_ends = []
self.tail = pathnode
self.dispatch_list(node.body)
loose_ends.append(self.tail)
+ for extra in extra_blocks:
+ self.tail = pathnode
+ self.dispatch_list(extra.body)
+ loose_ends.append(self.tail)
if node.orelse:
self.tail = pathnode
self.dispatch_list(node.orelse)
@@ -203,19 +207,9 @@
def visitTryExcept(self, node):
name = "TryExcept %d" % node.lineno
- pathnode = self.appendPathNode(name)
- loose_ends = []
- self.dispatch_list(node.body)
- loose_ends.append(self.tail)
- for handler in node.handlers:
- self.tail = pathnode
- self.dispatch_list(handler.body)
- loose_ends.append(self.tail)
- if pathnode:
- bottom = PathNode("", look='point')
- for le in loose_ends:
- self.graph.connect(le, bottom)
- self.tail = bottom
+ self._subgraph(node, name, extra_blocks=node.handlers)
+
+ visitTry = visitTryExcept
def visitWith(self, node):
name = "With %d" % node.lineno
diff --git a/test_mccabe.py b/test_mccabe.py
index 71a4ffb..89756b6 100644
--- a/test_mccabe.py
+++ b/test_mccabe.py
@@ -72,6 +72,17 @@
b()
"""
+try_else = """\
+try:
+ print(1)
+except TypeA:
+ print(2)
+except TypeB:
+ print(3)
+else:
+ print(4)
+"""
+
def get_complexity_number(snippet, strio, max=0):
"""Get the complexity number from the printed string."""
@@ -83,15 +94,26 @@
else:
return None
-
class McCabeTestCase(unittest.TestCase):
def setUp(self):
# If not assigned to sys.stdout then getvalue() won't capture anything.
+ self._orig_stdout = sys.stdout
sys.stdout = self.strio = StringIO()
def tearDown(self):
# https://mail.python.org/pipermail/tutor/2012-January/088031.html
self.strio.close()
+ sys.stdout = self._orig_stdout
+
+ def assert_complexity(self, snippet, max):
+ complexity = get_complexity_number(snippet, self.strio)
+ self.assertEqual(complexity, max)
+
+ # should have the same complexity when inside a function as well.
+ infunc = 'def f():\n ' + snippet.replace('\n', '\n ')
+ complexity = get_complexity_number(infunc, self.strio)
+ self.assertEqual(complexity, max)
+
def test_print_message(self):
get_code_complexity(sequential, 0)
@@ -138,6 +160,9 @@
printed_message = self.strio.getvalue()
self.assertEqual(printed_message, "")
+ def test_try_else(self):
+ self.assert_complexity(try_else, 4)
+
if __name__ == "__main__":
unittest.main()