blob: 472406edad71dd01311405f79d052f5214dafaf7 [file] [log] [blame]
diff --git a/third_party/tlslite/tlslite/TLSConnection.py b/third_party/tlslite/tlslite/TLSConnection.py
index f8811a9..e882e2c 100644
--- a/third_party/tlslite/tlslite/TLSConnection.py
+++ b/third_party/tlslite/tlslite/TLSConnection.py
@@ -611,6 +611,8 @@ class TLSConnection(TLSRecordLayer):
settings.cipherImplementations)
#Exchange ChangeCipherSpec and Finished messages
+ for result in self._getChangeCipherSpec():
+ yield result
for result in self._getFinished():
yield result
for result in self._sendFinished():
@@ -920,6 +922,8 @@ class TLSConnection(TLSRecordLayer):
#Exchange ChangeCipherSpec and Finished messages
for result in self._sendFinished():
yield result
+ for result in self._getChangeCipherSpec():
+ yield result
for result in self._getFinished():
yield result
@@ -1089,6 +1093,7 @@ class TLSConnection(TLSRecordLayer):
clientCertChain = None
serverCertChain = None #We may set certChain to this later
postFinishedError = None
+ doingChannelID = False
#Tentatively set version to most-desirable version, so if an error
#occurs parsing the ClientHello, this is what we'll use for the
@@ -1208,6 +1213,8 @@ class TLSConnection(TLSRecordLayer):
serverHello.create(self.version, serverRandom,
session.sessionID, session.cipherSuite,
certificateType)
+ serverHello.channel_id = clientHello.channel_id
+ doingChannelID = clientHello.channel_id
for result in self._sendMsg(serverHello):
yield result
@@ -1221,6 +1228,11 @@ class TLSConnection(TLSRecordLayer):
#Exchange ChangeCipherSpec and Finished messages
for result in self._sendFinished():
yield result
+ for result in self._getChangeCipherSpec():
+ yield result
+ if doingChannelID:
+ for result in self._getEncryptedExtensions():
+ yield result
for result in self._getFinished():
yield result
@@ -1399,8 +1411,12 @@ class TLSConnection(TLSRecordLayer):
#Send ServerHello, Certificate[, CertificateRequest],
#ServerHelloDone
msgs = []
- msgs.append(ServerHello().create(self.version, serverRandom,
- sessionID, cipherSuite, certificateType))
+ serverHello = ServerHello().create(
+ self.version, serverRandom,
+ sessionID, cipherSuite, certificateType)
+ serverHello.channel_id = clientHello.channel_id
+ doingChannelID = clientHello.channel_id
+ msgs.append(serverHello)
msgs.append(Certificate(certificateType).create(serverCertChain))
if reqCert and reqCAs:
msgs.append(CertificateRequest().create([], reqCAs))
@@ -1528,6 +1544,11 @@ class TLSConnection(TLSRecordLayer):
settings.cipherImplementations)
#Exchange ChangeCipherSpec and Finished messages
+ for result in self._getChangeCipherSpec():
+ yield result
+ if doingChannelID:
+ for result in self._getEncryptedExtensions():
+ yield result
for result in self._getFinished():
yield result
diff --git a/third_party/tlslite/tlslite/TLSRecordLayer.py b/third_party/tlslite/tlslite/TLSRecordLayer.py
index 1bbd09d..933b95a 100644
--- a/third_party/tlslite/tlslite/TLSRecordLayer.py
+++ b/third_party/tlslite/tlslite/TLSRecordLayer.py
@@ -714,6 +714,8 @@ class TLSRecordLayer:
self.version).parse(p)
elif subType == HandshakeType.finished:
yield Finished(self.version).parse(p)
+ elif subType == HandshakeType.encrypted_extensions:
+ yield EncryptedExtensions().parse(p)
else:
raise AssertionError()
@@ -1067,7 +1069,7 @@ class TLSRecordLayer:
for result in self._sendMsg(finished):
yield result
- def _getFinished(self):
+ def _getChangeCipherSpec(self):
#Get and check ChangeCipherSpec
for result in self._getMsg(ContentType.change_cipher_spec):
if result in (0,1):
@@ -1082,6 +1084,15 @@ class TLSRecordLayer:
#Switch to pending read state
self._changeReadState()
+ def _getEncryptedExtensions(self):
+ for result in self._getMsg(ContentType.handshake,
+ HandshakeType.encrypted_extensions):
+ if result in (0,1):
+ yield result
+ encrypted_extensions = result
+ self.channel_id = encrypted_extensions.channel_id_key
+
+ def _getFinished(self):
#Calculate verification data
verifyData = self._calcFinished(False)
diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py
index 04302c0..e357dd0 100644
--- a/third_party/tlslite/tlslite/constants.py
+++ b/third_party/tlslite/tlslite/constants.py
@@ -22,6 +22,7 @@ class HandshakeType:
certificate_verify = 15
client_key_exchange = 16
finished = 20
+ encrypted_extensions = 203
class ContentType:
change_cipher_spec = 20
@@ -30,6 +31,9 @@ class ContentType:
application_data = 23
all = (20,21,22,23)
+class ExtensionType:
+ channel_id = 30031
+
class AlertLevel:
warning = 1
fatal = 2
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
index dc6ed32..fa4d817 100644
--- a/third_party/tlslite/tlslite/messages.py
+++ b/third_party/tlslite/tlslite/messages.py
@@ -130,6 +130,7 @@ class ClientHello(HandshakeMsg):
self.certificate_types = [CertificateType.x509]
self.compression_methods = [] # a list of 8-bit values
self.srp_username = None # a string
+ self.channel_id = False
def create(self, version, random, session_id, cipher_suites,
certificate_types=None, srp_username=None):
@@ -174,6 +175,8 @@ class ClientHello(HandshakeMsg):
self.srp_username = bytesToString(p.getVarBytes(1))
elif extType == 7:
self.certificate_types = p.getVarList(1, 1)
+ elif extType == ExtensionType.channel_id:
+ self.channel_id = True
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
@@ -220,6 +223,7 @@ class ServerHello(HandshakeMsg):
self.cipher_suite = 0
self.certificate_type = CertificateType.x509
self.compression_method = 0
+ self.channel_id = False
def create(self, version, random, session_id, cipher_suite,
certificate_type):
@@ -266,6 +270,9 @@ class ServerHello(HandshakeMsg):
CertificateType.x509:
extLength += 5
+ if self.channel_id:
+ extLength += 4
+
if extLength != 0:
w.add(extLength, 2)
@@ -275,6 +282,10 @@ class ServerHello(HandshakeMsg):
w.add(1, 2)
w.add(self.certificate_type, 1)
+ if self.channel_id:
+ w.add(ExtensionType.channel_id, 2)
+ w.add(0, 2)
+
return HandshakeMsg.postWrite(self, w, trial)
class Certificate(HandshakeMsg):
@@ -567,6 +578,28 @@ class Finished(HandshakeMsg):
w.addFixSeq(self.verify_data, 1)
return HandshakeMsg.postWrite(self, w, trial)
+class EncryptedExtensions(HandshakeMsg):
+ def __init__(self):
+ self.channel_id_key = None
+ self.channel_id_proof = None
+
+ def parse(self, p):
+ p.startLengthCheck(3)
+ soFar = 0
+ while soFar != p.lengthCheck:
+ extType = p.get(2)
+ extLength = p.get(2)
+ if extType == ExtensionType.channel_id:
+ if extLength != 32*4:
+ raise SyntaxError()
+ self.channel_id_key = p.getFixBytes(64)
+ self.channel_id_proof = p.getFixBytes(64)
+ else:
+ p.getFixBytes(extLength)
+ soFar += 4 + extLength
+ p.stopLengthCheck()
+ return self
+
class ApplicationData(Msg):
def __init__(self):
self.contentType = ContentType.application_data