Merge pull request #26 from bukzor/try-else

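Fix complexity of try/except/else blocks: treat each `except` handler and the `else` clause as separate branches, and also visit Python 3 `ast.Try` nodes (`visitTry`).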
diff --git a/mccabe.py b/mccabe.py
index 9738251..f7c1778 100644
--- a/mccabe.py
+++ b/mccabe.py
@@ -170,25 +170,29 @@
         name = "If %d" % node.lineno
         self._subgraph(node, name)
 
-    def _subgraph(self, node, name):
-        """create the subgraphs representing any `if` and `for` statements"""
+    def _subgraph(self, node, name, extra_blocks=()):
+        """create the subgraphs for `if`, `for` and `try` statements"""
         if self.graph is None:
             # global loop
             self.graph = PathGraph(name, name, node.lineno)
             pathnode = PathNode(name)
-            self._subgraph_parse(node, pathnode)
+            self._subgraph_parse(node, pathnode, extra_blocks)
             self.graphs["%s%s" % (self.classname, name)] = self.graph
             self.reset()
         else:
             pathnode = self.appendPathNode(name)
-            self._subgraph_parse(node, pathnode)
+            self._subgraph_parse(node, pathnode, extra_blocks)
 
-    def _subgraph_parse(self, node, pathnode):
-        """parse the body and any `else` block of `if` and `for` statements"""
+    def _subgraph_parse(self, node, pathnode, extra_blocks=()):
+        """parse the body, each extra block and any `else` block of `node`"""
         loose_ends = []
         self.tail = pathnode
         self.dispatch_list(node.body)
         loose_ends.append(self.tail)
+        for extra in extra_blocks:
+            self.tail = pathnode
+            self.dispatch_list(extra.body)
+            loose_ends.append(self.tail)
         if node.orelse:
             self.tail = pathnode
             self.dispatch_list(node.orelse)
@@ -203,19 +207,15 @@
 
     def visitTryExcept(self, node):
         name = "TryExcept %d" % node.lineno
-        pathnode = self.appendPathNode(name)
-        loose_ends = []
-        self.dispatch_list(node.body)
-        loose_ends.append(self.tail)
-        for handler in node.handlers:
-            self.tail = pathnode
-            self.dispatch_list(handler.body)
-            loose_ends.append(self.tail)
-        if pathnode:
-            bottom = PathNode("", look='point')
-            for le in loose_ends:
-                self.graph.connect(le, bottom)
-            self.tail = bottom
+        self._subgraph(node, name, extra_blocks=node.handlers)
+        # Python 3's ast.Try also carries a `finally:` suite (finalbody) that
+        # is not among the blocks handed to the subgraph above; visit it
+        # afterwards so branches inside `finally:` still add complexity.
+        # Python 2's TryExcept node has no finalbody, hence the getattr.
+        self.dispatch_list(getattr(node, 'finalbody', []))
+
+    # Python 3 merges TryExcept and TryFinally into a single ast.Try node.
+    visitTry = visitTryExcept
 
     def visitWith(self, node):
         name = "With %d" % node.lineno
diff --git a/test_mccabe.py b/test_mccabe.py
index 71a4ffb..89756b6 100644
--- a/test_mccabe.py
+++ b/test_mccabe.py
@@ -72,6 +72,17 @@
     b()
 """
 
+try_else = """\
+try:
+    print(1)
+except TypeA:
+    print(2)
+except TypeB:
+    print(3)
+else:
+    print(4)
+"""
+
 
 def get_complexity_number(snippet, strio, max=0):
     """Get the complexity number from the printed string."""
@@ -83,15 +94,26 @@
     else:
         return None
 
 
 class McCabeTestCase(unittest.TestCase):
     def setUp(self):
         # If not assigned to sys.stdout then getvalue() won't capture anything.
+        self._orig_stdout = sys.stdout
         sys.stdout = self.strio = StringIO()
 
     def tearDown(self):
         # https://mail.python.org/pipermail/tutor/2012-January/088031.html
         self.strio.close()
+        sys.stdout = self._orig_stdout
+
+    def assert_complexity(self, snippet, expected):
+        complexity = get_complexity_number(snippet, self.strio)
+        self.assertEqual(complexity, expected)
+
+        # should have the same complexity when inside a function as well.
+        infunc = 'def f():\n    ' + snippet.replace('\n', '\n    ')
+        complexity = get_complexity_number(infunc, self.strio)
+        self.assertEqual(complexity, expected)
 
     def test_print_message(self):
         get_code_complexity(sequential, 0)
@@ -138,6 +160,9 @@
         printed_message = self.strio.getvalue()
         self.assertEqual(printed_message, "")
 
+    def test_try_else(self):
+        self.assert_complexity(try_else, 4)
+
 
 if __name__ == "__main__":
     unittest.main()