blob: 33d47d6df1a5f2681f6380df53460dec5dddccff [file] [log] [blame]
diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py
index 715def9..e9743e4 100644
--- a/third_party/tlslite/tlslite/constants.py
+++ b/third_party/tlslite/tlslite/constants.py
@@ -54,6 +54,7 @@ class ExtensionType: # RFC 6066 / 4366
status_request = 5 # RFC 6066 / 4366
srp = 12 # RFC 5054
cert_type = 9 # RFC 6091
+ alpn = 16 # RFC 7301
signed_cert_timestamps = 18 # RFC 6962
extended_master_secret = 23 # RFC 7627
token_binding = 24 # draft-ietf-tokbind-negotiation
diff --git a/third_party/tlslite/tlslite/handshakesettings.py b/third_party/tlslite/tlslite/handshakesettings.py
index d7be5b3..69fc6f4 100644
--- a/third_party/tlslite/tlslite/handshakesettings.py
+++ b/third_party/tlslite/tlslite/handshakesettings.py
@@ -128,6 +128,12 @@ class HandshakeSettings(object):
Note that TACK support is not standardized by IETF and uses a temporary
TLS Extension number, so should NOT be used in production software.
+
+ @type alpnProtos: list of strings.
+ @param alpnProtos: A list of supported upper layer protocols to use in the
+ Application-Layer Protocol Negotiation Extension (RFC 7301). For the
+ client, the order does not matter. For the server, the list is in
+ decreasing order of preference.
"""
def __init__(self):
self.minKeySize = 1023
@@ -146,6 +152,7 @@ class HandshakeSettings(object):
self.enableChannelID = True
self.enableExtendedMasterSecret = True
self.supportedTokenBindingParams = []
+ self.alpnProtos = None
# Validates the min/max fields, and certificateTypes
# Filters out unsupported cipherNames and cipherImplementations
@@ -166,6 +173,7 @@ class HandshakeSettings(object):
other.enableChannelID = self.enableChannelID
other.enableExtendedMasterSecret = self.enableExtendedMasterSecret
other.supportedTokenBindingParams = self.supportedTokenBindingParams
+ other.alpnProtos = self.alpnProtos;
if not cipherfactory.tripleDESPresent:
other.cipherNames = [e for e in self.cipherNames if e != "3des"]
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
index 5762ac6..1ce9320 100644
--- a/third_party/tlslite/tlslite/messages.py
+++ b/third_party/tlslite/tlslite/messages.py
@@ -18,6 +18,27 @@ from .x509 import X509
from .x509certchain import X509CertChain
from .utils.tackwrapper import *
+def parse_next_protos(b):
+ protos = []
+ while True:
+ if len(b) == 0:
+ break
+ l = b[0]
+ b = b[1:]
+ if len(b) < l:
+ raise BadNextProtos(len(b))
+ protos.append(b[:l])
+ b = b[l:]
+ return protos
+
+def next_protos_encoded(protocol_list):
+ b = bytearray()
+ for e in protocol_list:
+ if len(e) > 255 or len(e) == 0:
+ raise BadNextProtos(len(e))
+ b += bytearray( [len(e)] ) + bytearray(e)
+ return b
+
class RecordHeader3(object):
def __init__(self):
self.type = 0
@@ -111,6 +132,7 @@ class ClientHello(HandshakeMsg):
self.compression_methods = [] # a list of 8-bit values
self.srp_username = None # a string
self.tack = False
+ self.alpn_protos_advertised = None
self.supports_npn = False
self.server_name = bytearray(0)
self.channel_id = False
@@ -121,7 +143,8 @@ class ClientHello(HandshakeMsg):
def create(self, version, random, session_id, cipher_suites,
certificate_types=None, srpUsername=None,
- tack=False, supports_npn=False, serverName=None):
+ tack=False, alpn_protos_advertised=None,
+ supports_npn=False, serverName=None):
self.client_version = version
self.random = random
self.session_id = session_id
@@ -131,6 +154,7 @@ class ClientHello(HandshakeMsg):
if srpUsername:
self.srp_username = bytearray(srpUsername, "utf-8")
self.tack = tack
+ self.alpn_protos_advertised = alpn_protos_advertised
self.supports_npn = supports_npn
if serverName:
self.server_name = bytearray(serverName, "utf-8")
@@ -171,6 +195,11 @@ class ClientHello(HandshakeMsg):
self.certificate_types = p.getVarList(1, 1)
elif extType == ExtensionType.tack:
self.tack = True
+ elif extType == ExtensionType.alpn:
+ structLength = p.get(2)
+ if structLength + 2 != extLength:
+ raise SyntaxError()
+ self.alpn_protos_advertised = parse_next_protos(p.getFixBytes(structLength))
elif extType == ExtensionType.supports_npn:
self.supports_npn = True
elif extType == ExtensionType.server_name:
@@ -243,6 +272,12 @@ class ClientHello(HandshakeMsg):
w2.add(ExtensionType.srp, 2)
w2.add(len(self.srp_username)+1, 2)
w2.addVarSeq(self.srp_username, 1, 1)
+ if self.alpn_protos_advertised is not None:
+ encoded_alpn_protos_advertised = next_protos_encoded(self.alpn_protos_advertised)
+ w2.add(ExtensionType.alpn, 2)
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
+ w2.add(len(encoded_alpn_protos_advertised), 2)
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1)
if self.supports_npn:
w2.add(ExtensionType.supports_npn, 2)
w2.add(0, 2)
@@ -267,6 +302,13 @@ class BadNextProtos(Exception):
def __str__(self):
return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
+class InvalidALPNResponse(Exception):
+ def __init__(self, l):
+ self.length = l
+
+ def __str__(self):
+ return 'ALPN server response protocol list has invalid length %d. It must be of length one.' % self.length
+
class ServerHello(HandshakeMsg):
def __init__(self):
HandshakeMsg.__init__(self, HandshakeType.server_hello)
@@ -277,6 +319,7 @@ class ServerHello(HandshakeMsg):
self.certificate_type = CertificateType.x509
self.compression_method = 0
self.tackExt = None
+ self.alpn_proto_selected = None
self.next_protos_advertised = None
self.next_protos = None
self.channel_id = False
@@ -286,7 +329,8 @@ class ServerHello(HandshakeMsg):
self.status_request = False
def create(self, version, random, session_id, cipher_suite,
- certificate_type, tackExt, next_protos_advertised):
+ certificate_type, tackExt, alpn_proto_selected,
+ next_protos_advertised):
self.server_version = version
self.random = random
self.session_id = session_id
@@ -294,6 +338,7 @@ class ServerHello(HandshakeMsg):
self.certificate_type = certificate_type
self.compression_method = 0
self.tackExt = tackExt
+ self.alpn_proto_selected = alpn_proto_selected
self.next_protos_advertised = next_protos_advertised
return self
@@ -316,35 +361,22 @@ class ServerHello(HandshakeMsg):
self.certificate_type = p.get(1)
elif extType == ExtensionType.tack and tackpyLoaded:
self.tackExt = TackExtension(p.getFixBytes(extLength))
+ elif extType == ExtensionType.alpn:
+ structLength = p.get(2)
+ if structLength + 2 != extLength:
+ raise SyntaxError()
+ alpn_protos = parse_next_protos(p.getFixBytes(structLength))
+ if len(alpn_protos) != 1:
+ raise InvalidALPNResponse(len(alpn_protos));
+ self.alpn_proto_selected = alpn_protos[0]
elif extType == ExtensionType.supports_npn:
- self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
+ self.next_protos = parse_next_protos(p.getFixBytes(extLength))
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
- def __parse_next_protos(self, b):
- protos = []
- while True:
- if len(b) == 0:
- break
- l = b[0]
- b = b[1:]
- if len(b) < l:
- raise BadNextProtos(len(b))
- protos.append(b[:l])
- b = b[l:]
- return protos
-
- def __next_protos_encoded(self):
- b = bytearray()
- for e in self.next_protos_advertised:
- if len(e) > 255 or len(e) == 0:
- raise BadNextProtos(len(e))
- b += bytearray( [len(e)] ) + bytearray(e)
- return b
-
def write(self):
w = Writer()
w.add(self.server_version[0], 1)
@@ -365,8 +397,15 @@ class ServerHello(HandshakeMsg):
w2.add(ExtensionType.tack, 2)
w2.add(len(b), 2)
w2.bytes += b
+ if self.alpn_proto_selected is not None:
+ alpn_protos_single_element_list = [self.alpn_proto_selected]
+ encoded_alpn_protos_advertised = next_protos_encoded(alpn_protos_single_element_list)
+ w2.add(ExtensionType.alpn, 2)
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
+ w2.add(len(encoded_alpn_protos_advertised), 2)
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1)
if self.next_protos_advertised is not None:
- encoded_next_protos_advertised = self.__next_protos_encoded()
+ encoded_next_protos_advertised = next_protos_encoded(self.next_protos_advertised)
w2.add(ExtensionType.supports_npn, 2)
w2.add(len(encoded_next_protos_advertised), 2)
w2.addFixSeq(encoded_next_protos_advertised, 1)
diff --git a/third_party/tlslite/tlslite/tlsconnection.py b/third_party/tlslite/tlslite/tlsconnection.py
index 41aab85..de5d580 100644
--- a/third_party/tlslite/tlslite/tlsconnection.py
+++ b/third_party/tlslite/tlslite/tlsconnection.py
@@ -495,6 +495,10 @@ class TLSConnection(TLSRecordLayer):
settings = HandshakeSettings()
settings = settings._filter()
+ if settings.alpnProtos is not None:
+ if len(settings.alpnProtos) == 0:
+ raise ValueError("Caller passed no alpnProtos")
+
if clientCertChain:
if not isinstance(clientCertChain, X509CertChain):
raise ValueError("Unrecognized certificate type")
@@ -651,7 +655,8 @@ class TLSConnection(TLSRecordLayer):
session.sessionID, cipherSuites,
certificateTypes,
session.srpUsername,
- reqTack, nextProtos is not None,
+ reqTack, settings.alpnProtos,
+ nextProtos is not None,
session.serverName)
#Or send ClientHello (without)
@@ -661,7 +666,8 @@ class TLSConnection(TLSRecordLayer):
bytearray(0), cipherSuites,
certificateTypes,
srpUsername,
- reqTack, nextProtos is not None,
+ reqTack, settings.alpnProtos,
+ nextProtos is not None,
serverName)
for result in self._sendMsg(clientHello):
yield result
@@ -714,6 +720,16 @@ class TLSConnection(TLSRecordLayer):
AlertDescription.illegal_parameter,
"Server responded with unrequested Tack Extension"):
yield result
+ if serverHello.alpn_proto_selected and not clientHello.alpn_protos_advertised:
+ for result in self._sendError(\
+ AlertDescription.illegal_parameter,
+ "Server responded with unrequested ALPN Extension"):
+ yield result
+ if serverHello.alpn_proto_selected and serverHello.next_protos:
+ for result in self._sendError(\
+ AlertDescription.illegal_parameter,
+ "Server responded with both ALPN and NPN extension"):
+ yield result
if serverHello.next_protos and not clientHello.supports_npn:
for result in self._sendError(\
AlertDescription.illegal_parameter,
@@ -1315,6 +1331,15 @@ class TLSConnection(TLSRecordLayer):
else:
sessionID = bytearray(0)
+ alpn_proto_selected = None
+ if (clientHello.alpn_protos_advertised is not None
+ and settings.alpnProtos is not None):
+ for proto in settings.alpnProtos:
+ if proto in clientHello.alpn_protos_advertised:
+ alpn_proto_selected = proto
+ nextProtos = None
+ break;
+
if not clientHello.supports_npn:
nextProtos = None
@@ -1330,6 +1355,7 @@ class TLSConnection(TLSRecordLayer):
serverHello = ServerHello()
serverHello.create(self.version, getRandomBytes(32), sessionID, \
cipherSuite, CertificateType.x509, tackExt,
+ alpn_proto_selected,
nextProtos)
serverHello.channel_id = \
clientHello.channel_id and settings.enableChannelID
@@ -1500,6 +1526,14 @@ class TLSConnection(TLSRecordLayer):
else:
assert(False)
+ alpn_proto_selected = None
+ if (clientHello.alpn_protos_advertised is not None
+ and settings.alpnProtos is not None):
+ for proto in settings.alpnProtos:
+ if proto in clientHello.alpn_protos_advertised:
+ alpn_proto_selected = proto
+ break;
+
#If resumption was requested and we have a session cache...
if clientHello.session_id and sessionCache:
session = None
@@ -1540,7 +1574,8 @@ class TLSConnection(TLSRecordLayer):
serverHello = ServerHello()
serverHello.create(self.version, getRandomBytes(32),
session.sessionID, session.cipherSuite,
- CertificateType.x509, None, None)
+ CertificateType.x509, None,
+ alpn_proto_selected, None)
serverHello.extended_master_secret = \
clientHello.extended_master_secret and \
settings.enableExtendedMasterSecret