Make test_ssl pass in an IPv6-only environment (#827)

* Make test_ssl pass in an IPv6-only environment

* Review comments

* Update tests/test_ssl.py

Co-Authored-By: davidben <davidben@davidben.net>

* Wrap long line with parens.
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index ed911de..362da5c 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -10,9 +10,10 @@
 import uuid
 
 from gc import collect, get_referrers
-from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
+from errno import (
+    EAFNOSUPPORT, ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN)
 from sys import platform, getfilesystemencoding
-from socket import MSG_PEEK, SHUT_RDWR, error, socket
+from socket import AF_INET, AF_INET6, MSG_PEEK, SHUT_RDWR, error, socket
 from os import makedirs
 from os.path import join
 from weakref import ref
@@ -101,6 +102,23 @@
 skip_if_py3 = pytest.mark.skipif(PY3, reason="Python 2 only")
 
 
+def socket_any_family():
+    try:
+        return socket(AF_INET)
+    except error as e:
+        if e.errno == EAFNOSUPPORT:
+            return socket(AF_INET6)
+        raise
+
+
+def loopback_address(socket):
+    if socket.family == AF_INET:
+        return "127.0.0.1"
+    else:
+        assert socket.family == AF_INET6
+        return "::1"
+
+
 def join_bytes_or_unicode(prefix, suffix):
     """
     Join two path components of either ``bytes`` or ``unicode``.
@@ -127,12 +145,12 @@
     Establish and return a pair of network sockets connected to each other.
     """
     # Connect a pair of sockets
-    port = socket()
+    port = socket_any_family()
     port.bind(('', 0))
     port.listen(1)
-    client = socket()
+    client = socket(port.family)
     client.setblocking(False)
-    client.connect_ex(("127.0.0.1", port.getsockname()[1]))
+    client.connect_ex((loopback_address(port), port.getsockname()[1]))
     client.setblocking(True)
     server = port.accept()[0]
 
@@ -1209,7 +1227,7 @@
             VERIFY_PEER,
             lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
 
-        client = socket()
+        client = socket_any_family()
         client.connect(("encrypted.google.com", 443))
         clientSSL = Connection(context, client)
         clientSSL.set_connect_state()
@@ -2237,7 +2255,7 @@
         `Connection.connect` raises `TypeError` if called with a non-address
         argument.
         """
-        connection = Connection(Context(TLSv1_METHOD), socket())
+        connection = Connection(Context(TLSv1_METHOD), socket_any_family())
         with pytest.raises(TypeError):
             connection.connect(None)
 
@@ -2246,13 +2264,13 @@
         `Connection.connect` raises `socket.error` if the underlying socket
         connect method raises it.
         """
-        client = socket()
+        client = socket_any_family()
         context = Context(TLSv1_METHOD)
         clientSSL = Connection(context, client)
         # pytest.raises here doesn't work because of a bug in py.test on Python
         # 2.6: https://github.com/pytest-dev/pytest/issues/988
         try:
-            clientSSL.connect(("127.0.0.1", 1))
+            clientSSL.connect((loopback_address(client), 1))
         except error as e:
             exc = e
         assert exc.args[0] == ECONNREFUSED
@@ -2261,12 +2279,12 @@
         """
         `Connection.connect` establishes a connection to the specified address.
         """
-        port = socket()
+        port = socket_any_family()
         port.bind(('', 0))
         port.listen(3)
 
-        clientSSL = Connection(Context(TLSv1_METHOD), socket())
-        clientSSL.connect(('127.0.0.1', port.getsockname()[1]))
+        clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
+        clientSSL.connect((loopback_address(port), port.getsockname()[1]))
         # XXX An assertion?  Or something?
 
     @pytest.mark.skipif(
@@ -2278,11 +2296,11 @@
         If there is a connection error, `Connection.connect_ex` returns the
         errno instead of raising an exception.
         """
-        port = socket()
+        port = socket_any_family()
         port.bind(('', 0))
         port.listen(3)
 
-        clientSSL = Connection(Context(TLSv1_METHOD), socket())
+        clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
         clientSSL.setblocking(False)
         result = clientSSL.connect_ex(port.getsockname())
         expected = (EINPROGRESS, EWOULDBLOCK)
@@ -2297,16 +2315,16 @@
         ctx = Context(TLSv1_METHOD)
         ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
         ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
-        port = socket()
+        port = socket_any_family()
         portSSL = Connection(ctx, port)
         portSSL.bind(('', 0))
         portSSL.listen(3)
 
-        clientSSL = Connection(Context(TLSv1_METHOD), socket())
+        clientSSL = Connection(Context(TLSv1_METHOD), socket(port.family))
 
         # Calling portSSL.getsockname() here to get the server IP address
         # sounds great, but frequently fails on Windows.
-        clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1]))
+        clientSSL.connect((loopback_address(port), portSSL.getsockname()[1]))
 
         serverSSL, address = portSSL.accept()
 
@@ -2379,7 +2397,7 @@
         `Connection.set_shutdown` sets the state of the SSL connection
         shutdown process.
         """
-        connection = Connection(Context(TLSv1_METHOD), socket())
+        connection = Connection(Context(TLSv1_METHOD), socket_any_family())
         connection.set_shutdown(RECEIVED_SHUTDOWN)
         assert connection.get_shutdown() == RECEIVED_SHUTDOWN
 
@@ -2389,7 +2407,7 @@
         On Python 2 `Connection.set_shutdown` accepts an argument
         of type `long` as well as `int`.
         """
-        connection = Connection(Context(TLSv1_METHOD), socket())
+        connection = Connection(Context(TLSv1_METHOD), socket_any_family())
         connection.set_shutdown(long(RECEIVED_SHUTDOWN))
         assert connection.get_shutdown() == RECEIVED_SHUTDOWN
 
@@ -3503,7 +3521,7 @@
         work on `OpenSSL.SSL.Connection`() that use sockets.
         """
         context = Context(TLSv1_METHOD)
-        client = socket()
+        client = socket_any_family()
         clientSSL = Connection(context, client)
         with pytest.raises(TypeError):
             clientSSL.bio_read(100)