blob: 976b62359310277d3f9d1ba3c6e73cd5cb37ec92 [file] [log] [blame]
# -*- coding: utf-8 -*-
"""
hyper/ssl_compat
~~~~~~~~~~~~~~~~
Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl
module.
Currently exposes exactly those attributes, classes, and methods that we
actually use in hyper (all method signatures are complete, however). May be
expanded to something more general-purpose in the future.
"""
try:
import StringIO as BytesIO
except ImportError:
from io import BytesIO
import errno
import socket
import time
from OpenSSL import SSL as ossl
from service_identity.pyopenssl import verify_hostname as _verify
# Stdlib-style verification modes expressed as pyOpenSSL verify flags.
CERT_NONE = ossl.VERIFY_NONE
CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT

# Stdlib ``ssl`` constant names mapped to the pyOpenSSL attribute that backs
# them. A name is only exported if the installed pyOpenSSL provides it.
_OPENSSL_ATTRS = dict(
    OP_NO_COMPRESSION='OP_NO_COMPRESSION',
    PROTOCOL_TLSv1_2='TLSv1_2_METHOD',
    PROTOCOL_SSLv23='SSLv23_METHOD',
)
for external, internal in _OPENSSL_ATTRS.items():
    value = getattr(ossl, internal, None)
    # Compare against None explicitly so that a constant whose value happens
    # to be 0 is still exported rather than treated as missing.
    if value is not None:
        # Write to globals() explicitly: the mapping returned by locals()
        # is documented as one that should not be modified.
        globals()[external] = value

# Stand-in for the stdlib's ``ssl.OP_ALL``: bits 0-9 plus bit 31
# (presumably OpenSSL's bug-workaround flags -- confirm against the headers).
OP_ALL = 0
# TODO: Find out the names of these other flags.
for bit in [31] + list(range(10)):
    OP_ALL |= 1 << bit

# Reported unconditionally; this module does not probe pyOpenSSL for NPN.
HAS_NPN = True
def _proxy(method):
def inner(self, *args, **kwargs):
return getattr(self._conn, method)(*args, **kwargs)
return inner
# hyper/http20/connection.py refers to these error codes. They copy the
# values of the same-named constants in the stdlib ``ssl`` package and have
# to be defined here for hyper to run on Python versions older than 2.7.9.
SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE = 2, 3
# TODO: the stdlib's SSLError carries extra attributes (such as ``library``
# and ``reason``) that this stand-in does not provide.
class SSLError(OSError):
    """Stand-in for the stdlib ``ssl.SSLError``; a plain ``OSError`` subclass."""


class CertificateError(SSLError):
    """Stand-in for the stdlib ``ssl.CertificateError``."""
def verify_hostname(ssl_sock, server_hostname):
    """
    Check the peer certificate of *ssl_sock* against *server_hostname*.

    Close to the stdlib's ``match_hostname``. *server_hostname* may be text
    or ASCII ``bytes``; bytes are decoded before use. The check itself is
    delegated to service_identity's pyOpenSSL ``verify_hostname``, which
    receives the wrapped ``OpenSSL.SSL.Connection`` (``ssl_sock._conn``).
    """
    hostname = server_hostname
    if isinstance(hostname, bytes):
        hostname = hostname.decode('ascii')
    connection = ssl_sock._conn
    return _verify(connection, hostname)
class SSLSocket(object):
    """
    Socket-like wrapper around an ``OpenSSL.SSL.Connection`` that exposes
    the subset of the stdlib ``ssl.SSLSocket`` API used by hyper.
    """

    # Seconds _safe_ssl_call keeps retrying after WantRead/WantWrite errors
    # before giving up with socket.timeout.
    SSL_TIMEOUT = 3
    # Seconds to sleep between those retries.
    SSL_RETRY = .01

    def __init__(self, conn, server_side, do_handshake_on_connect,
                 suppress_ragged_eofs, server_hostname, check_hostname):
        """
        :param conn: the ``OpenSSL.SSL.Connection`` to wrap.
        :param server_side: put the connection into accept (server) state
            instead of connect (client) state.
        :param do_handshake_on_connect: run the handshake as soon as the
            socket is connected (here if it already is, else in connect()).
        :param suppress_ragged_eofs: make ``recv`` return ``b''`` instead of
            raising when the peer drops the connection without a TLS
            shutdown.
        :param server_hostname: host name sent via SNI on the client side and
            used for hostname checking; text or bytes.
        :param check_hostname: verify the peer certificate against
            *server_hostname* after the handshake.
        """
        self._conn = conn
        self._do_handshake_on_connect = do_handshake_on_connect
        self._suppress_ragged_eofs = suppress_ragged_eofs
        self._check_hostname = check_hostname
        # Always defined so do_handshake() can read it for every socket,
        # including server-side ones and ones created without a hostname.
        self._server_hostname = server_hostname
        if server_side:
            self._conn.set_accept_state()
        else:
            if server_hostname:
                sni_name = server_hostname
                # Accept bytes as well as text (as verify_hostname does);
                # bytes have no .encode() on Python 3.
                if not isinstance(sni_name, bytes):
                    sni_name = sni_name.encode('utf-8')
                self._conn.set_tlsext_host_name(sni_name)
            # FIXME does this override do_handshake_on_connect=False?
            self._conn.set_connect_state()
        if self.connected and self._do_handshake_on_connect:
            self.do_handshake()

    @property
    def connected(self):
        """
        ``True`` if the underlying socket has a peer address.

        ``ENOTCONN`` from ``getpeername()`` means "not connected"; any other
        socket error is re-raised.
        """
        try:
            self._conn.getpeername()
        except socket.error as e:
            if e.errno != errno.ENOTCONN:
                # It's an exception other than the one we expected if we're not
                # connected.
                raise
            return False
        return True

    # Lovingly stolen from CherryPy
    # (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py).
    def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs):
        """
        Invoke ``call(*args, **kwargs)`` with pyOpenSSL error trapping.

        ``WantReadError``/``WantWriteError`` are retried every ``SSL_RETRY``
        seconds; once more than ``SSL_TIMEOUT`` seconds have passed,
        ``socket.timeout`` is raised. An unexpected EOF returns ``b''`` when
        *suppress_ragged_eofs* is true. Every other pyOpenSSL error is
        re-raised as ``socket.error``.
        """
        start = time.time()
        while True:
            try:
                return call(*args, **kwargs)
            except (ossl.WantReadError, ossl.WantWriteError):
                # Sleep and try again. This is dangerous, because it means
                # the rest of the stack has no way of differentiating
                # between a "new handshake" error and "client dropped".
                # Note this isn't an endless loop: there's a timeout below.
                time.sleep(self.SSL_RETRY)
            except ossl.Error as e:
                if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
                    return b''
                # Some pyOpenSSL errors (e.g. ZeroReturnError) carry no args;
                # indexing e.args[0] blindly would raise IndexError instead.
                detail = e.args[0] if e.args else type(e).__name__
                raise socket.error(detail)
            if time.time() - start > self.SSL_TIMEOUT:
                raise socket.timeout('timed out')

    def connect(self, address):
        """Connect the underlying socket, then handshake if so configured."""
        self._conn.connect(address)
        if self._do_handshake_on_connect:
            self.do_handshake()

    def do_handshake(self):
        """
        Run the TLS handshake, then check the peer certificate against the
        server hostname when hostname checking was requested.
        """
        self._safe_ssl_call(False, self._conn.do_handshake)
        if self._check_hostname:
            verify_hostname(self, self._server_hostname)

    def recv(self, bufsize, flags=None):
        """
        Read up to *bufsize* decrypted bytes; ``b''`` signals end of stream.
        """
        def read():
            try:
                return self._conn.recv(bufsize, flags)
            except ossl.ZeroReturnError:
                # The peer closed the TLS session cleanly (close_notify):
                # report EOF as the stdlib does, not an error.
                return b''
        return self._safe_ssl_call(self._suppress_ragged_eofs, read)

    def recv_into(self, buffer, bufsize=None, flags=None):
        """
        Read into *buffer* (up to *bufsize*, default ``len(buffer)``) and
        return the number of bytes stored.
        """
        # A temporary recv_into implementation. Should be replaced when
        # PyOpenSSL has merged pyca/pyopenssl#121.
        if bufsize is None:
            bufsize = len(buffer)
        data = self.recv(bufsize, flags)
        data_len = len(data)
        buffer[0:data_len] = data
        return data_len

    def send(self, data, flags=None):
        """Send *data*; returns what ``Connection.send`` returns."""
        return self._safe_ssl_call(False, self._conn.send, data, flags)

    def sendall(self, data, flags=None):
        """Send all of *data* via ``Connection.sendall``."""
        return self._safe_ssl_call(False, self._conn.sendall, data, flags)

    def selected_npn_protocol(self):
        """Return the NPN-negotiated protocol as text, or ``None``."""
        proto = self._conn.get_next_proto_negotiated()
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def selected_alpn_protocol(self):
        """Return the ALPN-negotiated protocol as text, or ``None``."""
        proto = self._conn.get_alpn_proto_negotiated()
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def getpeercert(self):
        """
        Return the peer certificate as a dict modelled on the stdlib's
        ``getpeercert()`` output, or ``None`` if the peer sent no certificate.

        ``issuer`` and ``subject`` use the stdlib layout: a tuple of RDNs,
        each a tuple holding one ``(name, value)`` pair. ``version`` is
        one-based like the stdlib's. ``serialNumber``, ``notBefore`` and
        ``notAfter`` are passed through from pyOpenSSL unconverted (their
        formats likely differ from the stdlib's -- TODO).
        """
        cert = self._conn.get_peer_certificate()
        if cert is None:
            return None

        # Short X.509 attribute names -> the long names the stdlib reports.
        aliases = dict(
            C='countryName',
            ST='stateOrProvinceName',
            L='localityName',
            O='organizationName',
            OU='organizationalUnitName',
            CN='commonName',
        )

        def to_components(name):
            # TODO Verify that these are actually *supposed* to all be
            # single-element tuples, and that's not just a quirk of the
            # examples I've seen.
            components = []
            for raw_key, raw_value in name.get_components():
                key = raw_key.decode('utf-8')
                components.append(
                    ((aliases.get(key, key), raw_value.decode('utf-8')),)
                )
            return tuple(components)

        # The standard getpeercert() takes the nice X509 object tree returned
        # by OpenSSL and turns it into a dict according to some format it seems
        # to have made up on the spot. Here, we do our best to emulate that.
        result = dict(
            issuer=to_components(cert.get_issuer()),
            subject=to_components(cert.get_subject()),
            # pyOpenSSL's get_version() is zero-based (2 means X.509 v3); the
            # stdlib reports the one-based version number.
            version=cert.get_version() + 1,
            serialNumber=cert.get_serial_number(),
            notBefore=cert.get_notBefore(),
            notAfter=cert.get_notAfter(),
        )
        # TODO extensions, including subjectAltName
        # (see _decode_certificate in _ssl.c)
        return result

    # a dash of magic to reduce boilerplate: the class body's locals() is the
    # namespace being built, so each name becomes a forwarding method.
    methods = ['accept', 'bind', 'close', 'getsockname', 'listen', 'fileno']
    for method in methods:
        locals()[method] = _proxy(method)
class SSLContext(object):
    """
    Minimal stand-in for the stdlib ``ssl.SSLContext`` built on an
    ``OpenSSL.SSL.Context``.
    """

    def __init__(self, protocol):
        """
        :param protocol: a pyOpenSSL method constant, e.g. this module's
            ``PROTOCOL_SSLv23``.
        """
        self.protocol = protocol
        self._ctx = ossl.Context(protocol)
        self.options = OP_ALL
        self.check_hostname = False
        self.npn_protos = []

    @staticmethod
    def _to_bytes(value):
        """
        Return *value* as ``bytes``: ``None`` and ``bytes`` pass through,
        bytearray/memoryview are copied, and text is UTF-8 encoded.
        """
        if value is None or isinstance(value, bytes):
            return value
        if isinstance(value, (bytearray, memoryview)):
            return bytes(bytearray(value))
        return value.encode('utf-8')

    @property
    def options(self):
        """The option bits last assigned through this property."""
        return self._options

    @options.setter
    def options(self, value):
        # pyOpenSSL documents set_options() as additive (earlier options are
        # not cleared), so this records the value but cannot remove bits.
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        """Current verify flags; compare with CERT_NONE / CERT_REQUIRED."""
        return self._ctx.get_verify_mode()

    @verify_mode.setter
    def verify_mode(self, value):
        # The callback simply returns OpenSSL's own verdict (``ok``).
        # TODO verify exception is raised on failure
        self._ctx.set_verify(
            value, lambda conn, cert, errnum, errdepth, ok: ok
        )

    def set_default_verify_paths(self):
        """Load OpenSSL's default CA certificate locations."""
        self._ctx.set_default_verify_paths()

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """
        Load CA certificates used to verify peers.

        :param cafile: path (text or bytes) to a file of CA certificates.
        :param capath: path (text or bytes) to a directory of CA certificates.
        :param cadata: in-memory certificates: text or bytes holding one or
            more PEM certificates, or bytes holding a single DER certificate.
        :raises TypeError: if all three arguments are ``None``, matching the
            stdlib.
        """
        if cafile is None and capath is None and cadata is None:
            raise TypeError("cafile, capath and cadata cannot be all omitted")
        # OpenSSL's SSL_CTX_load_verify_locations fails when both paths are
        # NULL, so only call it when at least one path was supplied.
        if cafile is not None or capath is not None:
            self._ctx.load_verify_locations(
                self._to_bytes(cafile), self._to_bytes(capath)
            )
        if cadata is not None:
            # pyOpenSSL's load_verify_locations only takes file system paths,
            # so in-memory certificates go straight into the cert store.
            store = self._ctx.get_cert_store()
            for cert in self._load_cadata(cadata):
                store.add_cert(cert)

    def _load_cadata(self, cadata):
        """
        Parse *cadata* into a list of ``OpenSSL.crypto.X509`` objects.

        Data containing a PEM ``BEGIN CERTIFICATE`` marker is split into its
        PEM certificates; anything else is parsed as one DER certificate
        (concatenated DER certificates are not supported).
        """
        from OpenSSL import crypto
        data = self._to_bytes(cadata)
        begin = b'-----BEGIN CERTIFICATE-----'
        end = b'-----END CERTIFICATE-----'
        if begin not in data:
            return [crypto.load_certificate(crypto.FILETYPE_ASN1, data)]
        return [
            crypto.load_certificate(crypto.FILETYPE_PEM, chunk + end)
            for chunk in data.split(end)
            if begin in chunk
        ]

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """
        Load the local certificate chain and its private key.

        :param certfile: path to a PEM file holding the certificate followed
            by any intermediate CA certificates.
        :param keyfile: path to the private key; defaults to *certfile*.
        :param password: password for an encrypted key: text, bytes, or a
            callable (called with no arguments) returning either.
        """
        # Unlike use_certificate_file, use_certificate_chain_file also loads
        # the certificates after the first one, so intermediates are sent to
        # peers as the stdlib's load_cert_chain does.
        self._ctx.use_certificate_chain_file(certfile)
        if password is not None:
            def password_callback(max_length, prompt_twice, userdata):
                secret = password() if callable(password) else password
                # pyOpenSSL requires the passphrase callback to return bytes.
                return self._to_bytes(secret)
            self._ctx.set_passwd_cb(password_callback)
        self._ctx.use_privatekey_file(keyfile or certfile)

    def set_npn_protocols(self, protocols):
        """
        Offer *protocols* (text, most preferred first) during NPN.
        """
        self.protocols = list(map(lambda x: x.encode('ascii'), protocols))

        def cb(conn, protos):
            # Pick the first protocol in our own preference order that the
            # server also offered; b'' when there is no overlap.
            overlap = set(protos) & set(self.protocols)
            for p in self.protocols:
                if p in overlap:
                    return p
            return b''

        self._ctx.set_npn_select_callback(cb)

    def set_alpn_protocols(self, protocols):
        """Advertise *protocols* (text) via ALPN."""
        protocols = list(map(lambda x: x.encode('ascii'), protocols))
        self._ctx.set_alpn_protos(protocols)

    def wrap_socket(self,
                    sock,
                    server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Wrap *sock* in an :class:`SSLSocket` using this context."""
        conn = ossl.Connection(self._ctx, sock)
        return SSLSocket(conn, server_side, do_handshake_on_connect,
                         suppress_ragged_eofs, server_hostname,
                         # TODO what if this is changed after the fact?
                         self.check_hostname)