| diff --git a/ssl/ssl.h b/ssl/ssl.h |
| index 34142fc..e2d1b09 100644 |
| --- a/ssl/ssl.h |
| +++ b/ssl/ssl.h |
| @@ -803,6 +803,16 @@ SSL_IMPORT SECStatus SSL_ReHandshakeWithTimeout(PRFileDesc *fd, |
| PRBool flushCache, |
| PRIntervalTime timeout); |
| |
| +/* Returns a SECItem containing the certificate_types field of the |
| +** CertificateRequest message. Each byte of the data is a TLS |
| +** ClientCertificateType value, and they are ordered from most preferred to |
| +** least. This function should only be called from the |
| +** SSL_GetClientAuthDataHook callback, and will return NULL if called at any |
| +** other time. The returned value is valid only until the callback returns, and |
| +** should not be freed. |
| +*/ |
| +SSL_IMPORT const SECItem * |
| +SSL_GetRequestedClientCertificateTypes(PRFileDesc *fd); |
| |
| #ifdef SSL_DEPRECATED_FUNCTION |
| /* deprecated! |
| diff --git a/ssl/ssl3con.c b/ssl/ssl3con.c |
| index 40ae885..cb59cc1 100644 |
| --- a/ssl/ssl3con.c |
| +++ b/ssl/ssl3con.c |
| @@ -7045,6 +7045,9 @@ ssl3_HandleCertificateRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length) |
| if (rv != SECSuccess) |
| goto loser; /* malformed, alert has been sent */ |
| |
| + PORT_Assert(!ss->requestedCertTypes); |
| + ss->requestedCertTypes = &cert_types; |
| + |
| if (isTLS12) { |
| rv = ssl3_ConsumeHandshakeVariable(ss, &algorithms, 2, &b, &length); |
| if (rv != SECSuccess) |
| @@ -7246,6 +7249,7 @@ loser: |
| PORT_SetError(errCode); |
| rv = SECFailure; |
| done: |
| + ss->requestedCertTypes = NULL; |
| if (arena != NULL) |
| PORT_FreeArena(arena, PR_FALSE); |
| #ifdef NSS_PLATFORM_CLIENT_AUTH |
| diff --git a/ssl/sslimpl.h b/ssl/sslimpl.h |
| index cda1869..9f59f5a 100644 |
| --- a/ssl/sslimpl.h |
| +++ b/ssl/sslimpl.h |
| @@ -1231,6 +1231,10 @@ struct sslSocketStr { |
| unsigned int sizeCipherSpecs; |
| const unsigned char * preferredCipher; |
| |
| + /* TLS ClientCertificateTypes requested during HandleCertificateRequest. */ |
| + /* Will be NULL at all other times. */ |
| + const SECItem *requestedCertTypes; |
| + |
| ssl3KeyPair * stepDownKeyPair; /* RSA step down keys */ |
| |
| /* Callbacks */ |
| diff --git a/ssl/sslsock.c b/ssl/sslsock.c |
| index 688f399..a939781 100644 |
| --- a/ssl/sslsock.c |
| +++ b/ssl/sslsock.c |
| @@ -1911,6 +1911,20 @@ SSL_HandshakeResumedSession(PRFileDesc *fd, PRBool *handshake_resumed) { |
| return SECSuccess; |
| } |
| |
| +const SECItem * |
| +SSL_GetRequestedClientCertificateTypes(PRFileDesc *fd) |
| +{ |
| + sslSocket *ss = ssl_FindSocket(fd); |
| + |
| + if (!ss) { |
| + SSL_DBG(("%d: SSL[%d]: bad socket in " |
| + "SSL_GetRequestedClientCertificateTypes", SSL_GETPID(), fd)); |
| + return NULL; |
| + } |
| + |
| + return ss->requestedCertTypes; |
| +} |
| + |
| /************************************************************************/ |
| /* The following functions are the TOP LEVEL SSL functions. |
| ** They all get called through the NSPRIOMethods table below. |
| @@ -2989,6 +3003,7 @@ ssl_NewSocket(PRBool makeLocks, SSLProtocolVariant protocolVariant) |
| sc->serverKeyBits = 0; |
| ss->certStatusArray[i] = NULL; |
| } |
| + ss->requestedCertTypes = NULL; |
| ss->stepDownKeyPair = NULL; |
| ss->dbHandle = CERT_GetDefaultCertDB(); |
| |