Add tests for TLS fallback on connection reset and close.

The reset tests currently fail on OpenSSL and do not work on Android. But
otherwise this gives us slightly better test coverage here.

BUG=372849

Review URL: https://codereview.chromium.org/342793003

git-svn-id: http://src.chromium.org/svn/trunk/src/third_party/tlslite@280188 4ff67af0-8c30-449e-8e8b-ad334ec8d88c
diff --git a/README.chromium b/README.chromium
index 5487bfc..ab0b941 100644
--- a/README.chromium
+++ b/README.chromium
@@ -35,3 +35,5 @@
   certificate_types field of CertificateRequest.
 - patches/ignore_write_failure.patch: Don't invalidate sessions on write
   failures.
+- patches/intolerance_options.patch: Add an option to further control
+  simulated TLS version intolerance.
diff --git a/patches/intolerance_options.patch b/patches/intolerance_options.patch
new file mode 100644
index 0000000..0bee1d6
--- /dev/null
+++ b/patches/intolerance_options.patch
@@ -0,0 +1,192 @@
+diff --git a/third_party/tlslite/tlslite/handshakesettings.py b/third_party/tlslite/tlslite/handshakesettings.py
+index e0bc0e6..0d4ccf2 100644
+--- a/third_party/tlslite/tlslite/handshakesettings.py
++++ b/third_party/tlslite/tlslite/handshakesettings.py
+@@ -18,6 +18,7 @@ ALL_MAC_NAMES = ["sha", "md5"]
+ KEY_EXCHANGE_NAMES = ["rsa", "dhe_rsa", "srp_sha", "srp_sha_rsa", "dh_anon"]
+ CIPHER_IMPLEMENTATIONS = ["openssl", "pycrypto", "python"]
+ CERTIFICATE_TYPES = ["x509"]
++TLS_INTOLERANCE_TYPES = ["alert", "close", "reset"]
+ 
+ class HandshakeSettings(object):
+     """This class encapsulates various parameters that can be used with
+@@ -92,6 +93,21 @@ class HandshakeSettings(object):
+     The default is (3,2).  (WARNING: Some servers may (improperly)
+     reject clients which offer support for TLS 1.1.  In this case,
+     try lowering maxVersion to (3,1)).
++
++    @type tlsIntolerant: tuple
++    @ivar tlsIntolerant: The TLS ClientHello version which the server
++    simulates intolerance of.
++
++    If tlsIntolerant is not None, the server will simulate TLS version
++    intolerance by aborting the handshake in response to all TLS versions
++    tlsIntolerant or higher.
++
++    @type tlsIntoleranceType: str
++    @ivar tlsIntoleranceType: How the server should react when simulating TLS
++    intolerance.
++
++    The allowed values are "alert" (return a fatal handshake_failure alert),
++    "close" (abruptly close the connection), and "reset" (send a TCP reset).
+     
+     @type useExperimentalTackExtension: bool
+     @ivar useExperimentalTackExtension: Whether to enabled TACK support.
+@@ -109,6 +125,8 @@ class HandshakeSettings(object):
+         self.certificateTypes = CERTIFICATE_TYPES
+         self.minVersion = (3,0)
+         self.maxVersion = (3,2)
++        self.tlsIntolerant = None
++        self.tlsIntoleranceType = 'alert'
+         self.useExperimentalTackExtension = False
+ 
+     # Validates the min/max fields, and certificateTypes
+@@ -124,6 +142,8 @@ class HandshakeSettings(object):
+         other.certificateTypes = self.certificateTypes
+         other.minVersion = self.minVersion
+         other.maxVersion = self.maxVersion
++        other.tlsIntolerant = self.tlsIntolerant
++        other.tlsIntoleranceType = self.tlsIntoleranceType
+ 
+         if not cipherfactory.tripleDESPresent:
+             other.cipherNames = [e for e in self.cipherNames if e != "3des"]
+@@ -165,6 +185,10 @@ class HandshakeSettings(object):
+             if s not in CERTIFICATE_TYPES:
+                 raise ValueError("Unknown certificate type: '%s'" % s)
+ 
++        if other.tlsIntoleranceType not in TLS_INTOLERANCE_TYPES:
++            raise ValueError(
++                "Unknown TLS intolerance type: '%s'" % other.tlsIntoleranceType)
++
+         if other.minVersion > other.maxVersion:
+             raise ValueError("Versions set incorrectly")
+ 
+diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py
+index 044ad59..7c1572f 100644
+--- a/third_party/tlslite/tlslite/tlsconnection.py
++++ b/third_party/tlslite/tlslite/tlsconnection.py
+@@ -1065,7 +1065,7 @@ class TLSConnection(TLSRecordLayer):
+                         reqCAs = None, reqCertTypes = None,
+                         tacks=None, activationFlags=0,
+                         nextProtos=None, anon=False,
+-                        tlsIntolerant=None, signedCertTimestamps=None,
++                        signedCertTimestamps=None,
+                         fallbackSCSV=False, ocspResponse=None):
+         """Perform a handshake in the role of server.
+ 
+@@ -1139,11 +1139,6 @@ class TLSConnection(TLSRecordLayer):
+         clients through the Next-Protocol Negotiation Extension, 
+         if they support it.
+ 
+-        @type tlsIntolerant: (int, int) or None
+-        @param tlsIntolerant: If tlsIntolerant is not None, the server will
+-        simulate TLS version intolerance by returning a fatal handshake_failure
+-        alert to all TLS versions tlsIntolerant or higher.
+-
+         @type signedCertTimestamps: str
+         @param signedCertTimestamps: A SignedCertificateTimestampList (as a
+         binary 8-bit string) that will be sent as a TLS extension whenever
+@@ -1175,7 +1170,7 @@ class TLSConnection(TLSRecordLayer):
+                 certChain, privateKey, reqCert, sessionCache, settings,
+                 checker, reqCAs, reqCertTypes,
+                 tacks=tacks, activationFlags=activationFlags, 
+-                nextProtos=nextProtos, anon=anon, tlsIntolerant=tlsIntolerant,
++                nextProtos=nextProtos, anon=anon,
+                 signedCertTimestamps=signedCertTimestamps,
+                 fallbackSCSV=fallbackSCSV, ocspResponse=ocspResponse):
+             pass
+@@ -1187,7 +1182,6 @@ class TLSConnection(TLSRecordLayer):
+                              reqCAs=None, reqCertTypes=None,
+                              tacks=None, activationFlags=0,
+                              nextProtos=None, anon=False,
+-                             tlsIntolerant=None,
+                              signedCertTimestamps=None,
+                              fallbackSCSV=False,
+                              ocspResponse=None
+@@ -1210,7 +1204,6 @@ class TLSConnection(TLSRecordLayer):
+             reqCAs=reqCAs, reqCertTypes=reqCertTypes,
+             tacks=tacks, activationFlags=activationFlags, 
+             nextProtos=nextProtos, anon=anon,
+-            tlsIntolerant=tlsIntolerant,
+             signedCertTimestamps=signedCertTimestamps,
+             fallbackSCSV=fallbackSCSV,
+             ocspResponse=ocspResponse)
+@@ -1223,7 +1216,7 @@ class TLSConnection(TLSRecordLayer):
+                              settings, reqCAs, reqCertTypes,
+                              tacks, activationFlags, 
+                              nextProtos, anon,
+-                             tlsIntolerant, signedCertTimestamps, fallbackSCSV,
++                             signedCertTimestamps, fallbackSCSV,
+                              ocspResponse):
+ 
+         self._handshakeStart(client=False)
+@@ -1261,7 +1254,7 @@ class TLSConnection(TLSRecordLayer):
+         # Handle ClientHello and resumption
+         for result in self._serverGetClientHello(settings, certChain,\
+                                             verifierDB, sessionCache,
+-                                            anon, tlsIntolerant, fallbackSCSV):
++                                            anon, fallbackSCSV):
+             if result in (0,1): yield result
+             elif result == None:
+                 self._handshakeDone(resumed=True)                
+@@ -1376,7 +1369,7 @@ class TLSConnection(TLSRecordLayer):
+ 
+ 
+     def _serverGetClientHello(self, settings, certChain, verifierDB,
+-                                sessionCache, anon, tlsIntolerant, fallbackSCSV):
++                                sessionCache, anon, fallbackSCSV):
+         #Initialize acceptable cipher suites
+         cipherSuites = []
+         if verifierDB:
+@@ -1413,11 +1406,21 @@ class TLSConnection(TLSRecordLayer):
+                 yield result
+ 
+         #If simulating TLS intolerance, reject certain TLS versions.
+-        elif (tlsIntolerant is not None and
+-            clientHello.client_version >= tlsIntolerant):
+-            for result in self._sendError(\
++        elif (settings.tlsIntolerant is not None and
++              clientHello.client_version >= settings.tlsIntolerant):
++            if settings.tlsIntoleranceType == "alert":
++                for result in self._sendError(\
+                     AlertDescription.handshake_failure):
+-                yield result
++                    yield result
++            elif settings.tlsIntoleranceType == "close":
++                self._abruptClose()
++                raise TLSUnsupportedError("Simulating version intolerance")
++            elif settings.tlsIntoleranceType == "reset":
++                self._abruptClose(reset=True)
++                raise TLSUnsupportedError("Simulating version intolerance")
++            else:
++                raise ValueError("Unknown intolerance type: '%s'" %
++                                 settings.tlsIntoleranceType)
+ 
+         #If client's version is too high, propose my highest version
+         elif clientHello.client_version > settings.maxVersion:
+diff --git a/third_party/tlslite/tlslite/tlsrecordlayer.py b/third_party/tlslite/tlslite/tlsrecordlayer.py
+index 370dc9a..23c2a2f 100644
+--- a/third_party/tlslite/tlslite/tlsrecordlayer.py
++++ b/third_party/tlslite/tlslite/tlsrecordlayer.py
+@@ -19,6 +19,7 @@ from .constants import *
+ from .utils.cryptomath import getRandomBytes
+ 
+ import socket
++import struct
+ import errno
+ import traceback
+ 
+@@ -523,6 +524,13 @@ class TLSRecordLayer(object):
+         self._shutdown(False)
+         raise TLSLocalAlert(alert, errorStr)
+ 
++    def _abruptClose(self, reset=False):
++        if reset:
++            #Set an SO_LINGER timeout of 0 to send a TCP RST.
++            self.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
++                            struct.pack('ii', 1, 0))
++        self._shutdown(False)
++
+     def _sendMsgs(self, msgs):
+         randomizeFirstBlock = True
+         for msg in msgs:
diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py
index e0bc0e6..0d4ccf2 100644
--- a/tlslite/handshakesettings.py
+++ b/tlslite/handshakesettings.py
@@ -18,6 +18,7 @@
 KEY_EXCHANGE_NAMES = ["rsa", "dhe_rsa", "srp_sha", "srp_sha_rsa", "dh_anon"]
 CIPHER_IMPLEMENTATIONS = ["openssl", "pycrypto", "python"]
 CERTIFICATE_TYPES = ["x509"]
+TLS_INTOLERANCE_TYPES = ["alert", "close", "reset"]
 
 class HandshakeSettings(object):
     """This class encapsulates various parameters that can be used with
@@ -92,6 +93,21 @@
     The default is (3,2).  (WARNING: Some servers may (improperly)
     reject clients which offer support for TLS 1.1.  In this case,
     try lowering maxVersion to (3,1)).
+
+    @type tlsIntolerant: tuple
+    @ivar tlsIntolerant: The TLS ClientHello version which the server
+    simulates intolerance of.
+
+    If tlsIntolerant is not None, the server will simulate TLS version
+    intolerance by aborting the handshake in response to all TLS versions
+    tlsIntolerant or higher.
+
+    @type tlsIntoleranceType: str
+    @ivar tlsIntoleranceType: How the server should react when simulating TLS
+    intolerance.
+
+    The allowed values are "alert" (return a fatal handshake_failure alert),
+    "close" (abruptly close the connection), and "reset" (send a TCP reset).
     
     @type useExperimentalTackExtension: bool
     @ivar useExperimentalTackExtension: Whether to enabled TACK support.
@@ -109,6 +125,8 @@
         self.certificateTypes = CERTIFICATE_TYPES
         self.minVersion = (3,0)
         self.maxVersion = (3,2)
+        self.tlsIntolerant = None
+        self.tlsIntoleranceType = 'alert'
         self.useExperimentalTackExtension = False
 
     # Validates the min/max fields, and certificateTypes
@@ -124,6 +142,8 @@
         other.certificateTypes = self.certificateTypes
         other.minVersion = self.minVersion
         other.maxVersion = self.maxVersion
+        other.tlsIntolerant = self.tlsIntolerant
+        other.tlsIntoleranceType = self.tlsIntoleranceType
 
         if not cipherfactory.tripleDESPresent:
             other.cipherNames = [e for e in self.cipherNames if e != "3des"]
@@ -165,6 +185,10 @@
             if s not in CERTIFICATE_TYPES:
                 raise ValueError("Unknown certificate type: '%s'" % s)
 
+        if other.tlsIntoleranceType not in TLS_INTOLERANCE_TYPES:
+            raise ValueError(
+                "Unknown TLS intolerance type: '%s'" % other.tlsIntoleranceType)
+
         if other.minVersion > other.maxVersion:
             raise ValueError("Versions set incorrectly")
 
diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py
index 044ad59..7c1572f 100644
--- a/tlslite/tlsconnection.py
+++ b/tlslite/tlsconnection.py
@@ -1065,7 +1065,7 @@
                         reqCAs = None, reqCertTypes = None,
                         tacks=None, activationFlags=0,
                         nextProtos=None, anon=False,
-                        tlsIntolerant=None, signedCertTimestamps=None,
+                        signedCertTimestamps=None,
                         fallbackSCSV=False, ocspResponse=None):
         """Perform a handshake in the role of server.
 
@@ -1139,11 +1139,6 @@
         clients through the Next-Protocol Negotiation Extension, 
         if they support it.
 
-        @type tlsIntolerant: (int, int) or None
-        @param tlsIntolerant: If tlsIntolerant is not None, the server will
-        simulate TLS version intolerance by returning a fatal handshake_failure
-        alert to all TLS versions tlsIntolerant or higher.
-
         @type signedCertTimestamps: str
         @param signedCertTimestamps: A SignedCertificateTimestampList (as a
         binary 8-bit string) that will be sent as a TLS extension whenever
@@ -1175,7 +1170,7 @@
                 certChain, privateKey, reqCert, sessionCache, settings,
                 checker, reqCAs, reqCertTypes,
                 tacks=tacks, activationFlags=activationFlags, 
-                nextProtos=nextProtos, anon=anon, tlsIntolerant=tlsIntolerant,
+                nextProtos=nextProtos, anon=anon,
                 signedCertTimestamps=signedCertTimestamps,
                 fallbackSCSV=fallbackSCSV, ocspResponse=ocspResponse):
             pass
@@ -1187,7 +1182,6 @@
                              reqCAs=None, reqCertTypes=None,
                              tacks=None, activationFlags=0,
                              nextProtos=None, anon=False,
-                             tlsIntolerant=None,
                              signedCertTimestamps=None,
                              fallbackSCSV=False,
                              ocspResponse=None
@@ -1210,7 +1204,6 @@
             reqCAs=reqCAs, reqCertTypes=reqCertTypes,
             tacks=tacks, activationFlags=activationFlags, 
             nextProtos=nextProtos, anon=anon,
-            tlsIntolerant=tlsIntolerant,
             signedCertTimestamps=signedCertTimestamps,
             fallbackSCSV=fallbackSCSV,
             ocspResponse=ocspResponse)
@@ -1223,7 +1216,7 @@
                              settings, reqCAs, reqCertTypes,
                              tacks, activationFlags, 
                              nextProtos, anon,
-                             tlsIntolerant, signedCertTimestamps, fallbackSCSV,
+                             signedCertTimestamps, fallbackSCSV,
                              ocspResponse):
 
         self._handshakeStart(client=False)
@@ -1261,7 +1254,7 @@
         # Handle ClientHello and resumption
         for result in self._serverGetClientHello(settings, certChain,\
                                             verifierDB, sessionCache,
-                                            anon, tlsIntolerant, fallbackSCSV):
+                                            anon, fallbackSCSV):
             if result in (0,1): yield result
             elif result == None:
                 self._handshakeDone(resumed=True)                
@@ -1376,7 +1369,7 @@
 
 
     def _serverGetClientHello(self, settings, certChain, verifierDB,
-                                sessionCache, anon, tlsIntolerant, fallbackSCSV):
+                                sessionCache, anon, fallbackSCSV):
         #Initialize acceptable cipher suites
         cipherSuites = []
         if verifierDB:
@@ -1413,11 +1406,21 @@
                 yield result
 
         #If simulating TLS intolerance, reject certain TLS versions.
-        elif (tlsIntolerant is not None and
-            clientHello.client_version >= tlsIntolerant):
-            for result in self._sendError(\
+        elif (settings.tlsIntolerant is not None and
+              clientHello.client_version >= settings.tlsIntolerant):
+            if settings.tlsIntoleranceType == "alert":
+                for result in self._sendError(\
                     AlertDescription.handshake_failure):
-                yield result
+                    yield result
+            elif settings.tlsIntoleranceType == "close":
+                self._abruptClose()
+                raise TLSUnsupportedError("Simulating version intolerance")
+            elif settings.tlsIntoleranceType == "reset":
+                self._abruptClose(reset=True)
+                raise TLSUnsupportedError("Simulating version intolerance")
+            else:
+                raise ValueError("Unknown intolerance type: '%s'" %
+                                 settings.tlsIntoleranceType)
 
         #If client's version is too high, propose my highest version
         elif clientHello.client_version > settings.maxVersion:
diff --git a/tlslite/tlsrecordlayer.py b/tlslite/tlsrecordlayer.py
index 370dc9a..23c2a2f 100644
--- a/tlslite/tlsrecordlayer.py
+++ b/tlslite/tlsrecordlayer.py
@@ -19,6 +19,7 @@
 from .utils.cryptomath import getRandomBytes
 
 import socket
+import struct
 import errno
 import traceback
 
@@ -523,6 +524,13 @@
         self._shutdown(False)
         raise TLSLocalAlert(alert, errorStr)
 
+    def _abruptClose(self, reset=False):
+        if reset:
+            #Set an SO_LINGER timeout of 0 to send a TCP RST.
+            self.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
+                            struct.pack('ii', 1, 0))
+        self._shutdown(False)
+
     def _sendMsgs(self, msgs):
         randomizeFirstBlock = True
         for msg in msgs: