blob: 1ba92872aa39c5fbbf79f3c8df755a4bfe5513cf [file] [log] [blame]
diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py
index 4165de0..6429c66 100644
--- a/third_party/tlslite/tlslite/constants.py
+++ b/third_party/tlslite/tlslite/constants.py
@@ -32,6 +32,7 @@ class HandshakeType:
client_key_exchange = 16
finished = 20
next_protocol = 67
+ encrypted_extensions = 203
class ContentType:
change_cipher_spec = 20
@@ -46,6 +47,7 @@ class ExtensionType: # RFC 6066 / 4366
cert_type = 9 # RFC 6091
tack = 0xF300
supports_npn = 13172
+ channel_id = 30032
class NameType:
host_name = 0
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
index 2b3e518..4fa9d96 100644
--- a/third_party/tlslite/tlslite/messages.py
+++ b/third_party/tlslite/tlslite/messages.py
@@ -113,6 +113,7 @@ class ClientHello(HandshakeMsg):
self.tack = False
self.supports_npn = False
self.server_name = bytearray(0)
+ self.channel_id = False
def create(self, version, random, session_id, cipher_suites,
certificate_types=None, srpUsername=None,
@@ -180,6 +181,8 @@ class ClientHello(HandshakeMsg):
if name_type == NameType.host_name:
self.server_name = hostNameBytes
break
+ elif extType == ExtensionType.channel_id:
+ self.channel_id = True
else:
_ = p.getFixBytes(extLength)
index2 = p.index
@@ -244,6 +247,7 @@ class ServerHello(HandshakeMsg):
self.tackExt = None
self.next_protos_advertised = None
self.next_protos = None
+ self.channel_id = False
def create(self, version, random, session_id, cipher_suite,
certificate_type, tackExt, next_protos_advertised):
@@ -330,6 +334,9 @@ class ServerHello(HandshakeMsg):
w2.add(ExtensionType.supports_npn, 2)
w2.add(len(encoded_next_protos_advertised), 2)
w2.addFixSeq(encoded_next_protos_advertised, 1)
+ if self.channel_id:
+ w2.add(ExtensionType.channel_id, 2)
+ w2.add(0, 2)
if len(w2.bytes):
w.add(len(w2.bytes), 2)
w.bytes += w2.bytes
@@ -665,6 +672,28 @@ class Finished(HandshakeMsg):
w.addFixSeq(self.verify_data, 1)
return self.postWrite(w)
+class EncryptedExtensions(HandshakeMsg):
+ def __init__(self):
+ self.channel_id_key = None
+ self.channel_id_proof = None
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ soFar = 0
+ while soFar != p.lengthCheck:
+ extType = p.get(2)
+ extLength = p.get(2)
+ if extType == ExtensionType.channel_id:
+ if extLength != 32*4:
+ raise SyntaxError()
+ self.channel_id_key = p.getFixBytes(64)
+ self.channel_id_proof = p.getFixBytes(64)
+ else:
+ p.getFixBytes(extLength)
+ soFar += 4 + extLength
+ p.stopLengthCheck()
+ return self
+
class ApplicationData(object):
def __init__(self):
self.contentType = ContentType.application_data
diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py
index 0e78753..b0400f8 100644
--- a/third_party/tlslite/tlslite/tlsconnection.py
+++ b/third_party/tlslite/tlslite/tlsconnection.py
@@ -1158,6 +1158,7 @@ class TLSConnection(TLSRecordLayer):
serverHello.create(self.version, getRandomBytes(32), sessionID, \
cipherSuite, CertificateType.x509, tackExt,
nextProtos)
+ serverHello.channel_id = clientHello.channel_id
# Perform the SRP key exchange
clientCertChain = None
@@ -1194,7 +1195,7 @@ class TLSConnection(TLSRecordLayer):
for result in self._serverFinished(premasterSecret,
clientHello.random, serverHello.random,
cipherSuite, settings.cipherImplementations,
- nextProtos):
+ nextProtos, clientHello.channel_id):
if result in (0,1): yield result
else: break
masterSecret = result
@@ -1614,7 +1615,8 @@ class TLSConnection(TLSRecordLayer):
def _serverFinished(self, premasterSecret, clientRandom, serverRandom,
- cipherSuite, cipherImplementations, nextProtos):
+ cipherSuite, cipherImplementations, nextProtos,
+ doingChannelID):
masterSecret = calcMasterSecret(self.version, premasterSecret,
clientRandom, serverRandom)
@@ -1625,7 +1627,8 @@ class TLSConnection(TLSRecordLayer):
#Exchange ChangeCipherSpec and Finished messages
for result in self._getFinished(masterSecret,
- expect_next_protocol=nextProtos is not None):
+ expect_next_protocol=nextProtos is not None,
+ expect_channel_id=doingChannelID):
yield result
for result in self._sendFinished(masterSecret):
@@ -1662,7 +1665,8 @@ class TLSConnection(TLSRecordLayer):
for result in self._sendMsg(finished):
yield result
- def _getFinished(self, masterSecret, expect_next_protocol=False, nextProto=None):
+ def _getFinished(self, masterSecret, expect_next_protocol=False, nextProto=None,
+ expect_channel_id=False):
#Get and check ChangeCipherSpec
for result in self._getMsg(ContentType.change_cipher_spec):
if result in (0,1):
@@ -1695,6 +1699,20 @@ class TLSConnection(TLSRecordLayer):
if nextProto:
self.next_proto = nextProto
+ #Server Finish - Are we waiting for a EncryptedExtensions?
+ if expect_channel_id:
+ for result in self._getMsg(ContentType.handshake, HandshakeType.encrypted_extensions):
+ if result in (0,1):
+ yield result
+ if result is None:
+ for result in self._sendError(AlertDescription.unexpected_message,
+ "Didn't get EncryptedExtensions message"):
+ yield result
+ encrypted_extensions = result
+ self.channel_id = result.channel_id_key
+ else:
+ self.channel_id = None
+
#Calculate verification data
verifyData = self._calcFinished(masterSecret, False)
diff --git a/third_party/tlslite/tlslite/tlsrecordlayer.py b/third_party/tlslite/tlslite/tlsrecordlayer.py
index 5fe7410..f18fcf5 100644
--- a/third_party/tlslite/tlslite/tlsrecordlayer.py
+++ b/third_party/tlslite/tlslite/tlsrecordlayer.py
@@ -806,6 +806,8 @@ class TLSRecordLayer(object):
yield Finished(self.version).parse(p)
elif subType == HandshakeType.next_protocol:
yield NextProtocol().parse(p)
+ elif subType == HandshakeType.encrypted_extensions:
+ yield EncryptedExtensions().parse(p)
else:
raise AssertionError()