Add DHE_RSA support to tlslite.

Needed for our test server to be able to False Start.

BUG=354132

Review URL: https://codereview.chromium.org/212883008

git-svn-id: http://src.chromium.org/svn/trunk/src/third_party/tlslite@263169 4ff67af0-8c30-449e-8e8b-ad334ec8d88c
diff --git a/README.chromium b/README.chromium
index c88d58e..28dded0 100644
--- a/README.chromium
+++ b/README.chromium
@@ -30,3 +30,4 @@
   client_cipher_preferences.patch.
 - patches/fix_test_file.patch: Fix #! line in random test file to appease our
   presubmit checks.
+- patches/dhe_rsa.patch: Implement DHE_RSA-based cipher suites.
diff --git a/patches/dhe_rsa.patch b/patches/dhe_rsa.patch
new file mode 100644
index 0000000..1466851
--- /dev/null
+++ b/patches/dhe_rsa.patch
@@ -0,0 +1,417 @@
+diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py
+index 52c20ac..feca423 100644
+--- a/third_party/tlslite/tlslite/constants.py
++++ b/third_party/tlslite/tlslite/constants.py
+@@ -143,6 +143,10 @@ class CipherSuite:
+     
+     TLS_RSA_WITH_RC4_128_MD5 = 0x0004
+ 
++    TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA = 0x0016
++    TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
++    TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
++
+     TLS_DH_ANON_WITH_AES_128_CBC_SHA = 0x0034
+     TLS_DH_ANON_WITH_AES_256_CBC_SHA = 0x003A
+ 
+@@ -150,17 +154,20 @@ class CipherSuite:
+     tripleDESSuites.append(TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA)
+     tripleDESSuites.append(TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA)
+     tripleDESSuites.append(TLS_RSA_WITH_3DES_EDE_CBC_SHA)
++    tripleDESSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
+ 
+     aes128Suites = []
+     aes128Suites.append(TLS_SRP_SHA_WITH_AES_128_CBC_SHA)
+     aes128Suites.append(TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA)
+     aes128Suites.append(TLS_RSA_WITH_AES_128_CBC_SHA)
++    aes128Suites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
+     aes128Suites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
+ 
+     aes256Suites = []
+     aes256Suites.append(TLS_SRP_SHA_WITH_AES_256_CBC_SHA)
+     aes256Suites.append(TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA)
+     aes256Suites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
++    aes256Suites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
+     aes256Suites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
+ 
+     rc4Suites = []
+@@ -178,6 +185,9 @@ class CipherSuite:
+     shaSuites.append(TLS_RSA_WITH_AES_128_CBC_SHA)
+     shaSuites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
+     shaSuites.append(TLS_RSA_WITH_RC4_128_SHA)
++    shaSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
++    shaSuites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
++    shaSuites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
+     shaSuites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
+     shaSuites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
+     
+@@ -188,6 +198,7 @@ class CipherSuite:
+     def _filterSuites(suites, settings):
+         macNames = settings.macNames
+         cipherNames = settings.cipherNames
++        keyExchangeNames = settings.keyExchangeNames
+         macSuites = []
+         if "sha" in macNames:
+             macSuites += CipherSuite.shaSuites
+@@ -204,7 +215,20 @@ class CipherSuite:
+         if "rc4" in cipherNames:
+             cipherSuites += CipherSuite.rc4Suites
+ 
+-        return [s for s in suites if s in macSuites and s in cipherSuites]
++        keyExchangeSuites = []
++        if "rsa" in keyExchangeNames:
++            keyExchangeSuites += CipherSuite.certSuites
++        if "dhe_rsa" in keyExchangeNames:
++            keyExchangeSuites += CipherSuite.dheCertSuites
++        if "srp_sha" in keyExchangeNames:
++            keyExchangeSuites += CipherSuite.srpSuites
++        if "srp_sha_rsa" in keyExchangeNames:
++            keyExchangeSuites += CipherSuite.srpCertSuites
++        if "dh_anon" in keyExchangeNames:
++            keyExchangeSuites += CipherSuite.anonSuites
++
++        return [s for s in suites if s in macSuites and
++                s in cipherSuites and s in keyExchangeSuites]
+ 
+     srpSuites = []
+     srpSuites.append(TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA)
+@@ -236,12 +260,22 @@ class CipherSuite:
+     certSuites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
+     certSuites.append(TLS_RSA_WITH_RC4_128_SHA)
+     certSuites.append(TLS_RSA_WITH_RC4_128_MD5)
+-    certAllSuites = srpCertSuites + certSuites
+     
+     @staticmethod
+     def getCertSuites(settings):
+         return CipherSuite._filterSuites(CipherSuite.certSuites, settings)
+ 
++    dheCertSuites = []
++    dheCertSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
++    dheCertSuites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
++    dheCertSuites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
++
++    @staticmethod
++    def getDheCertSuites(settings):
++        return CipherSuite._filterSuites(CipherSuite.dheCertSuites, settings)
++
++    certAllSuites = srpCertSuites + certSuites + dheCertSuites
++
+     anonSuites = []
+     anonSuites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
+     anonSuites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
+@@ -250,6 +284,8 @@ class CipherSuite:
+     def getAnonSuites(settings):
+         return CipherSuite._filterSuites(CipherSuite.anonSuites, settings)
+ 
++    dhAllSuites = dheCertSuites + anonSuites
++
+     @staticmethod
+     def canonicalCipherName(ciphersuite):
+         "Return the canonical name of the cipher whose number is provided."
+diff --git a/third_party/tlslite/tlslite/handshakesettings.py b/third_party/tlslite/tlslite/handshakesettings.py
+index 7a38ee2..e0bc0e6 100644
+--- a/third_party/tlslite/tlslite/handshakesettings.py
++++ b/third_party/tlslite/tlslite/handshakesettings.py
+@@ -13,7 +13,9 @@ from .utils import cipherfactory
+ # RC4 is preferred as faster in Python, works in SSL3, and immune to CBC
+ # issues such as timing attacks
+ CIPHER_NAMES = ["rc4", "aes256", "aes128", "3des"]
+-MAC_NAMES = ["sha"] # "md5" is allowed
++MAC_NAMES = ["sha"] # Don't allow "md5" by default.
++ALL_MAC_NAMES = ["sha", "md5"]
++KEY_EXCHANGE_NAMES = ["rsa", "dhe_rsa", "srp_sha", "srp_sha_rsa", "dh_anon"]
+ CIPHER_IMPLEMENTATIONS = ["openssl", "pycrypto", "python"]
+ CERTIFICATE_TYPES = ["x509"]
+ 
+@@ -102,6 +104,7 @@ class HandshakeSettings(object):
+         self.maxKeySize = 8193
+         self.cipherNames = CIPHER_NAMES
+         self.macNames = MAC_NAMES
++        self.keyExchangeNames = KEY_EXCHANGE_NAMES
+         self.cipherImplementations = CIPHER_IMPLEMENTATIONS
+         self.certificateTypes = CERTIFICATE_TYPES
+         self.minVersion = (3,0)
+@@ -116,6 +119,7 @@ class HandshakeSettings(object):
+         other.maxKeySize = self.maxKeySize
+         other.cipherNames = self.cipherNames
+         other.macNames = self.macNames
++        other.keyExchangeNames = self.keyExchangeNames
+         other.cipherImplementations = self.cipherImplementations
+         other.certificateTypes = self.certificateTypes
+         other.minVersion = self.minVersion
+@@ -148,6 +152,12 @@ class HandshakeSettings(object):
+         for s in other.cipherNames:
+             if s not in CIPHER_NAMES:
+                 raise ValueError("Unknown cipher name: '%s'" % s)
++        for s in other.macNames:
++            if s not in ALL_MAC_NAMES:
++                raise ValueError("Unknown MAC name: '%s'" % s)
++        for s in other.keyExchangeNames:
++            if s not in KEY_EXCHANGE_NAMES:
++                raise ValueError("Unknown key exchange name: '%s'" % s)
+         for s in other.cipherImplementations:
+             if s not in CIPHER_IMPLEMENTATIONS:
+                 raise ValueError("Unknown cipher implementation: '%s'" % s)
+diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
+index 532d86b..550b387 100644
+--- a/third_party/tlslite/tlslite/messages.py
++++ b/third_party/tlslite/tlslite/messages.py
+@@ -533,31 +533,31 @@ class ServerKeyExchange(HandshakeMsg):
+         p.stopLengthCheck()
+         return self
+ 
+-    def write(self):
++    def write_params(self):
+         w = Writer()
+         if self.cipherSuite in CipherSuite.srpAllSuites:
+             w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
+             w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
+             w.addVarSeq(self.srp_s, 1, 1)
+             w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
+-            if self.cipherSuite in CipherSuite.srpCertSuites:
+-                w.addVarSeq(self.signature, 1, 2)
+-        elif self.cipherSuite in CipherSuite.anonSuites:
++        elif self.cipherSuite in CipherSuite.dhAllSuites:
+             w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
+             w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
+             w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
+-            if self.cipherSuite in []: # TODO support for signed_params
+-                w.addVarSeq(self.signature, 1, 2)
++        else:
++            assert(False)
++        return w.bytes
++
++    def write(self):
++        w = Writer()
++        w.bytes += self.write_params()
++        if self.cipherSuite in CipherSuite.certAllSuites:
++            w.addVarSeq(self.signature, 1, 2)
+         return self.postWrite(w)
+ 
+     def hash(self, clientRandom, serverRandom):
+-        oldCipherSuite = self.cipherSuite
+-        self.cipherSuite = None
+-        try:
+-            bytes = clientRandom + serverRandom + self.write()[4:]
+-            return MD5(bytes) + SHA1(bytes)
+-        finally:
+-            self.cipherSuite = oldCipherSuite
++        bytes = clientRandom + serverRandom + self.write_params()
++        return MD5(bytes) + SHA1(bytes)
+ 
+ class ServerHelloDone(HandshakeMsg):
+     def __init__(self):
+@@ -607,7 +607,7 @@ class ClientKeyExchange(HandshakeMsg):
+                     p.getFixBytes(len(p.bytes)-p.index)
+             else:
+                 raise AssertionError()
+-        elif self.cipherSuite in CipherSuite.anonSuites:
++        elif self.cipherSuite in CipherSuite.dhAllSuites:
+             self.dh_Yc = bytesToNumber(p.getVarBytes(2))            
+         else:
+             raise AssertionError()
+diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py
+index 20cd85b..e6f7820 100644
+--- a/third_party/tlslite/tlslite/tlsconnection.py
++++ b/third_party/tlslite/tlslite/tlsconnection.py
+@@ -23,6 +23,103 @@ from .mathtls import *
+ from .handshakesettings import HandshakeSettings
+ from .utils.tackwrapper import *
+ 
++class KeyExchange(object):
++    def __init__(self, cipherSuite, clientHello, serverHello, privateKey):
++        """
++        Initializes the KeyExchange. privateKey is the signing private key.
++        """
++        self.cipherSuite = cipherSuite
++        self.clientHello = clientHello
++        self.serverHello = serverHello
++        self.privateKey = privateKey
++
++    def makeServerKeyExchange():
++        """
++        Returns a ServerKeyExchange object for the server's initial leg in the
++        handshake. If the key exchange method does not send ServerKeyExchange
++        (e.g. RSA), it returns None.
++        """
++        raise NotImplementedError()
++
++    def processClientKeyExchange(clientKeyExchange):
++        """
++        Processes the client's ClientKeyExchange message and returns the
++        premaster secret. Raises TLSLocalAlert on error.
++        """
++        raise NotImplementedError()
++
++class RSAKeyExchange(KeyExchange):
++    def makeServerKeyExchange(self):
++        return None
++
++    def processClientKeyExchange(self, clientKeyExchange):
++        premasterSecret = self.privateKey.decrypt(\
++            clientKeyExchange.encryptedPreMasterSecret)
++
++        # On decryption failure randomize premaster secret to avoid
++        # Bleichenbacher's "million message" attack
++        randomPreMasterSecret = getRandomBytes(48)
++        if not premasterSecret:
++            premasterSecret = randomPreMasterSecret
++        elif len(premasterSecret)!=48:
++            premasterSecret = randomPreMasterSecret
++        else:
++            versionCheck = (premasterSecret[0], premasterSecret[1])
++            if versionCheck != self.clientHello.client_version:
++                #Tolerate buggy IE clients
++                if versionCheck != self.serverHello.server_version:
++                    premasterSecret = randomPreMasterSecret
++        return premasterSecret
++
++def _hexStringToNumber(s):
++    s = s.replace(" ", "").replace("\n", "")
++    if len(s) % 2 != 0:
++        raise ValueError("Length is not even")
++    return bytesToNumber(bytearray(s.decode("hex")))
++
++class DHE_RSAKeyExchange(KeyExchange):
++    # 2048-bit MODP Group (RFC 3526, Section 3)
++    dh_p = _hexStringToNumber("""
++FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
++29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
++EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
++E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
++EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
++C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
++83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
++670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
++E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
++DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
++15728E5A 8AACAA68 FFFFFFFF FFFFFFFF""")
++    dh_g = 2
++
++    # RFC 3526, Section 8.
++    strength = 160
++
++    def makeServerKeyExchange(self):
++        # Per RFC 3526, Section 1, the exponent should have double the entropy
++        # of the strength of the curve.
++        self.dh_Xs = bytesToNumber(getRandomBytes(self.strength * 2 / 8))
++        dh_Ys = powMod(self.dh_g, self.dh_Xs, self.dh_p)
++
++        serverKeyExchange = ServerKeyExchange(self.cipherSuite)
++        serverKeyExchange.createDH(self.dh_p, self.dh_g, dh_Ys)
++        serverKeyExchange.signature = self.privateKey.sign(
++            serverKeyExchange.hash(self.clientHello.random,
++                                   self.serverHello.random))
++        return serverKeyExchange
++
++    def processClientKeyExchange(self, clientKeyExchange):
++        dh_Yc = clientKeyExchange.dh_Yc
++
++        # First half of RFC 2631, Section 2.1.5. Validate the client's public
++        # key.
++        if not 2 <= dh_Yc <= self.dh_p - 1:
++            raise TLSLocalAlert(AlertDescription.illegal_parameter,
++                                "Invalid dh_Yc value")
++
++        S = powMod(dh_Yc, self.dh_Xs, self.dh_p)
++        return numberToByteArray(S)
+ 
+ class TLSConnection(TLSRecordLayer):
+     """
+@@ -500,6 +597,8 @@ class TLSConnection(TLSRecordLayer):
+             cipherSuites += CipherSuite.getSrpAllSuites(settings)
+         elif certParams:
+             cipherSuites += CipherSuite.getCertSuites(settings)
++            # TODO: Client DHE_RSA not supported.
++            # cipherSuites += CipherSuite.getDheCertSuites(settings)
+         elif anonParams:
+             cipherSuites += CipherSuite.getAnonSuites(settings)
+         else:
+@@ -1204,10 +1303,23 @@ class TLSConnection(TLSRecordLayer):
+                 else: break
+             premasterSecret = result
+ 
+-        # Perform the RSA key exchange
+-        elif cipherSuite in CipherSuite.certSuites:
++        # Perform the RSA or DHE_RSA key exchange
++        elif (cipherSuite in CipherSuite.certSuites or
++              cipherSuite in CipherSuite.dheCertSuites):
++            if cipherSuite in CipherSuite.certSuites:
++                keyExchange = RSAKeyExchange(cipherSuite,
++                                             clientHello,
++                                             serverHello,
++                                             privateKey)
++            elif cipherSuite in CipherSuite.dheCertSuites:
++                keyExchange = DHE_RSAKeyExchange(cipherSuite,
++                                                 clientHello,
++                                                 serverHello,
++                                                 privateKey)
++            else:
++                assert(False)
+             for result in self._serverCertKeyExchange(clientHello, serverHello, 
+-                                        certChain, privateKey,
++                                        certChain, keyExchange,
+                                         reqCert, reqCAs, cipherSuite,
+                                         settings, ocspResponse):
+                 if result in (0,1): yield result
+@@ -1268,6 +1380,7 @@ class TLSConnection(TLSRecordLayer):
+             cipherSuites += CipherSuite.getSrpSuites(settings)
+         elif certChain:
+             cipherSuites += CipherSuite.getCertSuites(settings)
++            cipherSuites += CipherSuite.getDheCertSuites(settings)
+         elif anon:
+             cipherSuites += CipherSuite.getAnonSuites(settings)
+         else:
+@@ -1483,11 +1596,11 @@ class TLSConnection(TLSRecordLayer):
+ 
+ 
+     def _serverCertKeyExchange(self, clientHello, serverHello, 
+-                                serverCertChain, privateKey,
++                                serverCertChain, keyExchange,
+                                 reqCert, reqCAs, cipherSuite,
+                                 settings, ocspResponse):
+-        #Send ServerHello, Certificate[, CertificateRequest],
+-        #ServerHelloDone
++        #Send ServerHello, Certificate[, ServerKeyExchange]
++        #[, CertificateRequest], ServerHelloDone
+         msgs = []
+ 
+         # If we verify a client cert chain, return it
+@@ -1497,6 +1610,9 @@ class TLSConnection(TLSRecordLayer):
+         msgs.append(Certificate(CertificateType.x509).create(serverCertChain))
+         if serverHello.status_request:
+             msgs.append(CertificateStatus().create(ocspResponse))
++        serverKeyExchange = keyExchange.makeServerKeyExchange()
++        if serverKeyExchange is not None:
++            msgs.append(serverKeyExchange)
+         if reqCert and reqCAs:
+             msgs.append(CertificateRequest().create(\
+                 [ClientCertificateType.rsa_sign], reqCAs))
+@@ -1555,21 +1671,13 @@ class TLSConnection(TLSRecordLayer):
+             else: break
+         clientKeyExchange = result
+ 
+-        #Decrypt ClientKeyExchange
+-        premasterSecret = privateKey.decrypt(\
+-            clientKeyExchange.encryptedPreMasterSecret)
+-
+-        # On decryption failure randomize premaster secret to avoid
+-        # Bleichenbacher's "million message" attack
+-        randomPreMasterSecret = getRandomBytes(48)
+-        versionCheck = (premasterSecret[0], premasterSecret[1])
+-        if not premasterSecret:
+-            premasterSecret = randomPreMasterSecret
+-        elif len(premasterSecret)!=48:
+-            premasterSecret = randomPreMasterSecret
+-        elif versionCheck != clientHello.client_version:
+-            if versionCheck != self.version: #Tolerate buggy IE clients
+-                premasterSecret = randomPreMasterSecret
++        #Process ClientKeyExchange
++        try:
++            premasterSecret = \
++                keyExchange.processClientKeyExchange(clientKeyExchange)
++        except TLSLocalAlert, alert:
++            for result in self._sendError(alert.description, alert.message):
++                yield result
+ 
+         #Get and check CertificateVerify, if relevant
+         if clientCertChain:
diff --git a/tlslite/constants.py b/tlslite/constants.py
index 52c20ac..feca423 100644
--- a/tlslite/constants.py
+++ b/tlslite/constants.py
@@ -143,6 +143,10 @@
     
     TLS_RSA_WITH_RC4_128_MD5 = 0x0004
 
+    TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA = 0x0016
+    TLS_DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033
+    TLS_DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039
+
     TLS_DH_ANON_WITH_AES_128_CBC_SHA = 0x0034
     TLS_DH_ANON_WITH_AES_256_CBC_SHA = 0x003A
 
@@ -150,17 +154,20 @@
     tripleDESSuites.append(TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA)
     tripleDESSuites.append(TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA)
     tripleDESSuites.append(TLS_RSA_WITH_3DES_EDE_CBC_SHA)
+    tripleDESSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
 
     aes128Suites = []
     aes128Suites.append(TLS_SRP_SHA_WITH_AES_128_CBC_SHA)
     aes128Suites.append(TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA)
     aes128Suites.append(TLS_RSA_WITH_AES_128_CBC_SHA)
+    aes128Suites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
     aes128Suites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
 
     aes256Suites = []
     aes256Suites.append(TLS_SRP_SHA_WITH_AES_256_CBC_SHA)
     aes256Suites.append(TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA)
     aes256Suites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
+    aes256Suites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
     aes256Suites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
 
     rc4Suites = []
@@ -178,6 +185,9 @@
     shaSuites.append(TLS_RSA_WITH_AES_128_CBC_SHA)
     shaSuites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
     shaSuites.append(TLS_RSA_WITH_RC4_128_SHA)
+    shaSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
+    shaSuites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
+    shaSuites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
     shaSuites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
     shaSuites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
     
@@ -188,6 +198,7 @@
     def _filterSuites(suites, settings):
         macNames = settings.macNames
         cipherNames = settings.cipherNames
+        keyExchangeNames = settings.keyExchangeNames
         macSuites = []
         if "sha" in macNames:
             macSuites += CipherSuite.shaSuites
@@ -204,7 +215,20 @@
         if "rc4" in cipherNames:
             cipherSuites += CipherSuite.rc4Suites
 
-        return [s for s in suites if s in macSuites and s in cipherSuites]
+        keyExchangeSuites = []
+        if "rsa" in keyExchangeNames:
+            keyExchangeSuites += CipherSuite.certSuites
+        if "dhe_rsa" in keyExchangeNames:
+            keyExchangeSuites += CipherSuite.dheCertSuites
+        if "srp_sha" in keyExchangeNames:
+            keyExchangeSuites += CipherSuite.srpSuites
+        if "srp_sha_rsa" in keyExchangeNames:
+            keyExchangeSuites += CipherSuite.srpCertSuites
+        if "dh_anon" in keyExchangeNames:
+            keyExchangeSuites += CipherSuite.anonSuites
+
+        return [s for s in suites if s in macSuites and
+                s in cipherSuites and s in keyExchangeSuites]
 
     srpSuites = []
     srpSuites.append(TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA)
@@ -236,12 +260,22 @@
     certSuites.append(TLS_RSA_WITH_AES_256_CBC_SHA)
     certSuites.append(TLS_RSA_WITH_RC4_128_SHA)
     certSuites.append(TLS_RSA_WITH_RC4_128_MD5)
-    certAllSuites = srpCertSuites + certSuites
     
     @staticmethod
     def getCertSuites(settings):
         return CipherSuite._filterSuites(CipherSuite.certSuites, settings)
 
+    dheCertSuites = []
+    dheCertSuites.append(TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA)
+    dheCertSuites.append(TLS_DHE_RSA_WITH_AES_128_CBC_SHA)
+    dheCertSuites.append(TLS_DHE_RSA_WITH_AES_256_CBC_SHA)
+
+    @staticmethod
+    def getDheCertSuites(settings):
+        return CipherSuite._filterSuites(CipherSuite.dheCertSuites, settings)
+
+    certAllSuites = srpCertSuites + certSuites + dheCertSuites
+
     anonSuites = []
     anonSuites.append(TLS_DH_ANON_WITH_AES_128_CBC_SHA)
     anonSuites.append(TLS_DH_ANON_WITH_AES_256_CBC_SHA)
@@ -250,6 +284,8 @@
     def getAnonSuites(settings):
         return CipherSuite._filterSuites(CipherSuite.anonSuites, settings)
 
+    dhAllSuites = dheCertSuites + anonSuites
+
     @staticmethod
     def canonicalCipherName(ciphersuite):
         "Return the canonical name of the cipher whose number is provided."
diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py
index 7a38ee2..e0bc0e6 100644
--- a/tlslite/handshakesettings.py
+++ b/tlslite/handshakesettings.py
@@ -13,7 +13,9 @@
 # RC4 is preferred as faster in Python, works in SSL3, and immune to CBC
 # issues such as timing attacks
 CIPHER_NAMES = ["rc4", "aes256", "aes128", "3des"]
-MAC_NAMES = ["sha"] # "md5" is allowed
+MAC_NAMES = ["sha"] # Don't allow "md5" by default.
+ALL_MAC_NAMES = ["sha", "md5"]
+KEY_EXCHANGE_NAMES = ["rsa", "dhe_rsa", "srp_sha", "srp_sha_rsa", "dh_anon"]
 CIPHER_IMPLEMENTATIONS = ["openssl", "pycrypto", "python"]
 CERTIFICATE_TYPES = ["x509"]
 
@@ -102,6 +104,7 @@
         self.maxKeySize = 8193
         self.cipherNames = CIPHER_NAMES
         self.macNames = MAC_NAMES
+        self.keyExchangeNames = KEY_EXCHANGE_NAMES
         self.cipherImplementations = CIPHER_IMPLEMENTATIONS
         self.certificateTypes = CERTIFICATE_TYPES
         self.minVersion = (3,0)
@@ -116,6 +119,7 @@
         other.maxKeySize = self.maxKeySize
         other.cipherNames = self.cipherNames
         other.macNames = self.macNames
+        other.keyExchangeNames = self.keyExchangeNames
         other.cipherImplementations = self.cipherImplementations
         other.certificateTypes = self.certificateTypes
         other.minVersion = self.minVersion
@@ -148,6 +152,12 @@
         for s in other.cipherNames:
             if s not in CIPHER_NAMES:
                 raise ValueError("Unknown cipher name: '%s'" % s)
+        for s in other.macNames:
+            if s not in ALL_MAC_NAMES:
+                raise ValueError("Unknown MAC name: '%s'" % s)
+        for s in other.keyExchangeNames:
+            if s not in KEY_EXCHANGE_NAMES:
+                raise ValueError("Unknown key exchange name: '%s'" % s)
         for s in other.cipherImplementations:
             if s not in CIPHER_IMPLEMENTATIONS:
                 raise ValueError("Unknown cipher implementation: '%s'" % s)
diff --git a/tlslite/messages.py b/tlslite/messages.py
index 532d86b..550b387 100644
--- a/tlslite/messages.py
+++ b/tlslite/messages.py
@@ -533,31 +533,31 @@
         p.stopLengthCheck()
         return self
 
-    def write(self):
+    def write_params(self):
         w = Writer()
         if self.cipherSuite in CipherSuite.srpAllSuites:
             w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
             w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
             w.addVarSeq(self.srp_s, 1, 1)
             w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
-            if self.cipherSuite in CipherSuite.srpCertSuites:
-                w.addVarSeq(self.signature, 1, 2)
-        elif self.cipherSuite in CipherSuite.anonSuites:
+        elif self.cipherSuite in CipherSuite.dhAllSuites:
             w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
             w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
             w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
-            if self.cipherSuite in []: # TODO support for signed_params
-                w.addVarSeq(self.signature, 1, 2)
+        else:
+            assert(False)
+        return w.bytes
+
+    def write(self):
+        w = Writer()
+        w.bytes += self.write_params()
+        if self.cipherSuite in CipherSuite.certAllSuites:
+            w.addVarSeq(self.signature, 1, 2)
         return self.postWrite(w)
 
     def hash(self, clientRandom, serverRandom):
-        oldCipherSuite = self.cipherSuite
-        self.cipherSuite = None
-        try:
-            bytes = clientRandom + serverRandom + self.write()[4:]
-            return MD5(bytes) + SHA1(bytes)
-        finally:
-            self.cipherSuite = oldCipherSuite
+        bytes = clientRandom + serverRandom + self.write_params()
+        return MD5(bytes) + SHA1(bytes)
 
 class ServerHelloDone(HandshakeMsg):
     def __init__(self):
@@ -607,7 +607,7 @@
                     p.getFixBytes(len(p.bytes)-p.index)
             else:
                 raise AssertionError()
-        elif self.cipherSuite in CipherSuite.anonSuites:
+        elif self.cipherSuite in CipherSuite.dhAllSuites:
             self.dh_Yc = bytesToNumber(p.getVarBytes(2))            
         else:
             raise AssertionError()
diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py
index 20cd85b..e6f7820 100644
--- a/tlslite/tlsconnection.py
+++ b/tlslite/tlsconnection.py
@@ -23,6 +23,103 @@
 from .handshakesettings import HandshakeSettings
 from .utils.tackwrapper import *
 
+class KeyExchange(object):
+    def __init__(self, cipherSuite, clientHello, serverHello, privateKey):
+        """
+        Initializes the KeyExchange. privateKey is the signing private key.
+        """
+        self.cipherSuite = cipherSuite
+        self.clientHello = clientHello
+        self.serverHello = serverHello
+        self.privateKey = privateKey
+
+    def makeServerKeyExchange():
+        """
+        Returns a ServerKeyExchange object for the server's initial leg in the
+        handshake. If the key exchange method does not send ServerKeyExchange
+        (e.g. RSA), it returns None.
+        """
+        raise NotImplementedError()
+
+    def processClientKeyExchange(clientKeyExchange):
+        """
+        Processes the client's ClientKeyExchange message and returns the
+        premaster secret. Raises TLSLocalAlert on error.
+        """
+        raise NotImplementedError()
+
+class RSAKeyExchange(KeyExchange):
+    def makeServerKeyExchange(self):
+        return None
+
+    def processClientKeyExchange(self, clientKeyExchange):
+        premasterSecret = self.privateKey.decrypt(\
+            clientKeyExchange.encryptedPreMasterSecret)
+
+        # On decryption failure randomize premaster secret to avoid
+        # Bleichenbacher's "million message" attack
+        randomPreMasterSecret = getRandomBytes(48)
+        if not premasterSecret:
+            premasterSecret = randomPreMasterSecret
+        elif len(premasterSecret)!=48:
+            premasterSecret = randomPreMasterSecret
+        else:
+            versionCheck = (premasterSecret[0], premasterSecret[1])
+            if versionCheck != self.clientHello.client_version:
+                #Tolerate buggy IE clients
+                if versionCheck != self.serverHello.server_version:
+                    premasterSecret = randomPreMasterSecret
+        return premasterSecret
+
+def _hexStringToNumber(s):
+    s = s.replace(" ", "").replace("\n", "")
+    if len(s) % 2 != 0:
+        raise ValueError("Length is not even")
+    return bytesToNumber(bytearray(s.decode("hex")))
+
+class DHE_RSAKeyExchange(KeyExchange):
+    # 2048-bit MODP Group (RFC 3526, Section 3)
+    dh_p = _hexStringToNumber("""
+FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
+29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
+EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
+E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
+EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
+C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
+83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
+670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
+E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
+DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
+15728E5A 8AACAA68 FFFFFFFF FFFFFFFF""")
+    dh_g = 2
+
+    # RFC 3526, Section 8.
+    strength = 160
+
+    def makeServerKeyExchange(self):
+        # Per RFC 3526, Section 1, the exponent should have double the entropy
+        # of the strength of the curve.
+        self.dh_Xs = bytesToNumber(getRandomBytes(self.strength * 2 / 8))
+        dh_Ys = powMod(self.dh_g, self.dh_Xs, self.dh_p)
+
+        serverKeyExchange = ServerKeyExchange(self.cipherSuite)
+        serverKeyExchange.createDH(self.dh_p, self.dh_g, dh_Ys)
+        serverKeyExchange.signature = self.privateKey.sign(
+            serverKeyExchange.hash(self.clientHello.random,
+                                   self.serverHello.random))
+        return serverKeyExchange
+
+    def processClientKeyExchange(self, clientKeyExchange):
+        dh_Yc = clientKeyExchange.dh_Yc
+
+        # First half of RFC 2631, Section 2.1.5. Validate the client's public
+        # key.
+        if not 2 <= dh_Yc <= self.dh_p - 1:
+            raise TLSLocalAlert(AlertDescription.illegal_parameter,
+                                "Invalid dh_Yc value")
+
+        S = powMod(dh_Yc, self.dh_Xs, self.dh_p)
+        return numberToByteArray(S)
 
 class TLSConnection(TLSRecordLayer):
     """
@@ -500,6 +597,8 @@
             cipherSuites += CipherSuite.getSrpAllSuites(settings)
         elif certParams:
             cipherSuites += CipherSuite.getCertSuites(settings)
+            # TODO: Client DHE_RSA not supported.
+            # cipherSuites += CipherSuite.getDheCertSuites(settings)
         elif anonParams:
             cipherSuites += CipherSuite.getAnonSuites(settings)
         else:
@@ -1204,10 +1303,23 @@
                 else: break
             premasterSecret = result
 
-        # Perform the RSA key exchange
-        elif cipherSuite in CipherSuite.certSuites:
+        # Perform the RSA or DHE_RSA key exchange
+        elif (cipherSuite in CipherSuite.certSuites or
+              cipherSuite in CipherSuite.dheCertSuites):
+            if cipherSuite in CipherSuite.certSuites:
+                keyExchange = RSAKeyExchange(cipherSuite,
+                                             clientHello,
+                                             serverHello,
+                                             privateKey)
+            elif cipherSuite in CipherSuite.dheCertSuites:
+                keyExchange = DHE_RSAKeyExchange(cipherSuite,
+                                                 clientHello,
+                                                 serverHello,
+                                                 privateKey)
+            else:
+                assert(False)
             for result in self._serverCertKeyExchange(clientHello, serverHello, 
-                                        certChain, privateKey,
+                                        certChain, keyExchange,
                                         reqCert, reqCAs, cipherSuite,
                                         settings, ocspResponse):
                 if result in (0,1): yield result
@@ -1268,6 +1380,7 @@
             cipherSuites += CipherSuite.getSrpSuites(settings)
         elif certChain:
             cipherSuites += CipherSuite.getCertSuites(settings)
+            cipherSuites += CipherSuite.getDheCertSuites(settings)
         elif anon:
             cipherSuites += CipherSuite.getAnonSuites(settings)
         else:
@@ -1483,11 +1596,11 @@
 
 
     def _serverCertKeyExchange(self, clientHello, serverHello, 
-                                serverCertChain, privateKey,
+                                serverCertChain, keyExchange,
                                 reqCert, reqCAs, cipherSuite,
                                 settings, ocspResponse):
-        #Send ServerHello, Certificate[, CertificateRequest],
-        #ServerHelloDone
+        #Send ServerHello, Certificate[, ServerKeyExchange]
+        #[, CertificateRequest], ServerHelloDone
         msgs = []
 
         # If we verify a client cert chain, return it
@@ -1497,6 +1610,9 @@
         msgs.append(Certificate(CertificateType.x509).create(serverCertChain))
         if serverHello.status_request:
             msgs.append(CertificateStatus().create(ocspResponse))
+        serverKeyExchange = keyExchange.makeServerKeyExchange()
+        if serverKeyExchange is not None:
+            msgs.append(serverKeyExchange)
         if reqCert and reqCAs:
             msgs.append(CertificateRequest().create(\
                 [ClientCertificateType.rsa_sign], reqCAs))
@@ -1555,21 +1671,13 @@
             else: break
         clientKeyExchange = result
 
-        #Decrypt ClientKeyExchange
-        premasterSecret = privateKey.decrypt(\
-            clientKeyExchange.encryptedPreMasterSecret)
-
-        # On decryption failure randomize premaster secret to avoid
-        # Bleichenbacher's "million message" attack
-        randomPreMasterSecret = getRandomBytes(48)
-        versionCheck = (premasterSecret[0], premasterSecret[1])
-        if not premasterSecret:
-            premasterSecret = randomPreMasterSecret
-        elif len(premasterSecret)!=48:
-            premasterSecret = randomPreMasterSecret
-        elif versionCheck != clientHello.client_version:
-            if versionCheck != self.version: #Tolerate buggy IE clients
-                premasterSecret = randomPreMasterSecret
+        #Process ClientKeyExchange
+        try:
+            premasterSecret = \
+                keyExchange.processClientKeyExchange(clientKeyExchange)
+        except TLSLocalAlert, alert:
+            for result in self._sendError(alert.description, alert.message):
+                yield result
 
         #Get and check CertificateVerify, if relevant
         if clientCertChain: