blob: b8a168dce24b4f1c5837d7184743278c73a557ff [file] [log] [blame]
# Copyright (c) 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
A http client with support for https connections with certificate verification.
The verification is based on http://tools.ietf.org/html/rfc6125#section-6.4.3
and the code is from Lib/ssl.py in python3:
http://hg.python.org/cpython/file/4dac45f88d45/Lib/ssl.py
One use case is to download Chromium DEPS file in a secure way:
https://src.chromium.org/chrome/trunk/src/DEPS
Notice: python 2.7 or newer is required.
"""
import cookielib
import httplib
import os
import re
import socket
import ssl
import time
import urllib
import urllib2
import http_client
_SCRIPT_DIR = os.path.dirname(__file__)
_TRUSTED_ROOT_CERTS = os.path.join(_SCRIPT_DIR, 'cacert.pem')
class CertificateError(ValueError):
pass
def _DNSNameMatch(dn, hostname, max_wildcards=1):
"""Matching according to RFC 6125, section 6.4.3
http://tools.ietf.org/html/rfc6125#section-6.4.3
"""
pats = []
if not dn:
return False
parts = dn.split(r'.')
leftmost = parts[0]
remainder = parts[1:]
wildcards = leftmost.count('*')
if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survery of established
# policy among SSL implementations showed it to be a
# reasonable choice.
raise CertificateError(
'too many wildcards in certificate DNS name: ' + repr(dn))
# speed up common case w/o wildcards
if not wildcards:
return dn.lower() == hostname.lower()
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label.
if leftmost == '*':
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append('[^.]+')
elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
# RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or
# U-label of an internationalized domain name.
pats.append(re.escape(leftmost))
else:
# Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))
# add the remaining fragments, ignore any wildcards
for frag in remainder:
pats.append(re.escape(frag))
pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
return pat.match(hostname)
def _MatchHostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
rules are followed, but IP addresses are not accepted for *hostname*.
CertificateError is raised on failure. On success, the function
returns nothing.
"""
if not cert:
raise ValueError('empty or no certificate, match_hostname needs a '
'SSL socket or SSL context with either '
'CERT_OPTIONAL or CERT_REQUIRED')
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
if key == 'DNS':
if _DNSNameMatch(value, hostname):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get('subject', ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == 'commonName':
if _DNSNameMatch(value, hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise CertificateError('hostname %r doesn\'t match either of %s'
% (hostname, ', '.join(map(repr, dnsnames))))
elif len(dnsnames) == 1:
raise CertificateError('hostname %r doesn\'t match %r'
% (hostname, dnsnames[0]))
else:
raise CertificateError('no appropriate commonName or '
'subjectAltName fields were found')
class HTTPSConnection(httplib.HTTPSConnection):
def __init__(self, host, root_certs=_TRUSTED_ROOT_CERTS, **kwargs):
self.root_certs = root_certs
httplib.HTTPSConnection.__init__(self, host, **kwargs)
def connect(self):
# Overrides for certificate verification.
args = [(self.host, self.port), self.timeout,]
if self.source_address:
args.append(self.source_address)
sock = socket.create_connection(*args)
if self._tunnel_host:
self.sock = sock
self._tunnel()
# Wrap the socket for verification with the root certs.
kwargs = {}
if self.root_certs is not None:
kwargs.update(cert_reqs=ssl.CERT_REQUIRED, ca_certs=self.root_certs)
self.sock = ssl.wrap_socket(sock, **kwargs)
# Check hostname.
try:
_MatchHostname(self.sock.getpeercert(), self.host)
except CertificateError:
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()
raise
class HTTPSHandler(urllib2.HTTPSHandler):
def __init__(self, root_certs=_TRUSTED_ROOT_CERTS):
urllib2.HTTPSHandler.__init__(self)
self.root_certs = root_certs
def https_open(self, req):
# Pass a reference to the function below so that verification against
# trusted root certs could be injected.
return self.do_open(self.GetConnection, req)
def GetConnection(self, host, **kwargs):
params = dict(root_certs=self.root_certs)
params.update(kwargs)
return HTTPSConnection(host, **params)
def _SendRequest(url, timeout=None):
"""Send request to the given https url, and return the server response.
Args:
url: The https url to send request to.
Returns:
An integer: http code of the response.
A string: content of the response.
Raises:
CertificateError: Certificate verification fails.
"""
if not url:
return None, None
handlers = []
if url.startswith('https://'):
# HTTPSHandler has to go first, because we don't want to send secure cookies
# to a man in the middle.
handlers.append(HTTPSHandler())
cookie_file = os.environ.get('COOKIE_FILE')
if cookie_file and os.path.exists(cookie_file):
handlers.append(
urllib2.HTTPCookieProcessor(cookielib.MozillaCookieJar(cookie_file)))
url_opener = urllib2.build_opener(*handlers)
status_code = None
content = None
try:
response = url_opener.open(url, timeout=timeout)
status_code = response.code
content = response.read()
except urllib2.HTTPError as e:
status_code = e.code
content = None
except (ssl.SSLError, httplib.BadStatusLine, IOError):
status_code = -1
content = None
return status_code, content
class HttpClientLocal(http_client.HttpClient):
"""This http client is used locally in a workstation, GCE VMs, etc."""
@staticmethod
def Get(url, params={}, timeout=120, retries=5, retry_interval=0.5,
retry_if_not=None):
if params:
url = '%s?%s' % (url, urllib.urlencode(params))
count = 0
while True:
count += 1
status_code, content = _SendRequest(url, timeout=timeout)
if status_code == 200:
return status_code, content
if retry_if_not and status_code == retry_if_not:
return status_code, content
if count < retries:
time.sleep(retry_interval)
else:
return status_code, content
# Should never be reached.
return status_code, content