blob: bbd387ea5e492e76987ad5e42c3d63064fefef4e [file] [log] [blame]
# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Spins up a trivial HTTP cgi form listener in a thread.

The HTTPListener class is a utility for use with test cases that need to call
back to the Autotest test case with some form value, e.g. by submitting
status="Browser started!" in a POST to http://localhost:nnnn/.

POSTed form fields are recorded on the server and can be read back with
HTTPListener.get_form_entries(). GET requests are dispatched to a registered
URL handler or served from the docroot, and add_wait_url() returns an event
that is set once the given path has been requested.
"""
import cgi
import errno
import functools
import logging
import os
import posixpath
import SimpleHTTPServer
import socket
import ssl
import sys
import threading
import urllib
import urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import BaseServer, ThreadingMixIn
def _handle_http_errors(func):
"""Decorator function for cleaner presentation of certain exceptions."""
def wrapper(self):
try:
func(self)
except IOError, e:
if e.errno == errno.EPIPE or e.errno == errno.ECONNRESET:
# Instead of dumping a stack trace, a single line is sufficient.
self.log_error(str(e))
else:
raise
return wrapper
class FormHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
    """Request handler that records and echoes POSTed form fields.

    POSTed key=value fields are stored in the server's _form_entries and,
    unless a URL handler is registered for the path, echoed back in the
    response. If a field is a file upload, the file is written to disk with
    the name contained in its 'filename' field. GET requests are dispatched
    to a registered URL handler or served from the server's docroot.

    The server object must carry docroot, _wait_urls, _url_handlers and
    _form_entries attributes (HTTPListener.config_server installs them).
    """

    # NOTE: this mutates SimpleHTTPRequestHandler.extensions_map in place, so
    # the .webm mapping is visible to every SimpleHTTPRequestHandler in the
    # process, not only to this subclass.
    SimpleHTTPServer.SimpleHTTPRequestHandler.extensions_map.update({
        '.webm': 'video/webm',
    })

    # Override the default logging methods to use the logging module directly.
    # The base class writes to stderr and appends its own '\n'; logging adds a
    # line terminator itself, so none is included here.
    def log_error(self, format, *args):
        logging.warning('(httpd error) %s - - [%s] %s',
                        self.address_string(), self.log_date_time_string(),
                        format % args)

    def log_message(self, format, *args):
        logging.debug('%s - - [%s] %s',
                      self.address_string(), self.log_date_time_string(),
                      format % args)

    @_handle_http_errors
    def do_POST(self):
        """Record the posted fields, then run the path's handler or echo."""
        form = cgi.FieldStorage(
            fp=self.rfile,
            headers=self.headers,
            environ={'REQUEST_METHOD': 'POST',
                     # .get(): a POST without a Content-Type header must not
                     # abort the request with a KeyError.
                     'CONTENT_TYPE': self.headers.get('Content-Type', '')})
        # You'd think form.keys() would just return [], like it does for empty
        # python dicts; you'd be wrong. It raises TypeError if called when it
        # has no keys.
        if form:
            for field in form.keys():
                # getvalue() yields a list of values for a repeated field,
                # where form[field].value would raise AttributeError.
                self.server._form_entries[field] = form.getvalue(field)
        path = urlparse.urlparse(self.path)[2]
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, form)
        else:
            # Echo back information about what was posted in the form.
            self.write_post_response(form)
        # NOTE(review): unlike do_GET, self.path still includes any query
        # string here, so a wait URL must match it exactly.
        self._fire_event()

    def write_post_response(self, form):
        """Called to fill out the response to an HTTP POST.

        Override this method to give custom responses. The default echoes
        the client address, the request path and every submitted field;
        uploaded files are written to disk under their client-supplied name.
        """
        # Send response boilerplate
        self.send_response(200)
        self.end_headers()
        self.wfile.write('Hello from Autotest!\nClient: %s\n' %
                         str(self.client_address))
        self.wfile.write('Request for path: %s\n' % self.path)
        self.wfile.write('Got form data:\n')
        # See the note in do_POST about form.keys().
        if not form:
            return
        for field in form.keys():
            field_item = form[field]
            # A repeated field name yields a list of items rather than one.
            items = field_item if isinstance(field_item, list) else [field_item]
            for item in items:
                if item.filename:
                    # The field contains an uploaded file
                    upload = item.file.read()
                    self.wfile.write('\tUploaded %s (%d bytes)<br>' %
                                     (field, len(upload)))
                    # Write submitted file to specified filename, in binary
                    # mode, closing the file promptly.
                    # NOTE(review): the filename comes from the client and is
                    # used as-is, so it may point outside the working dir.
                    with open(item.filename, 'wb') as out:
                        out.write(upload)
                    del upload
                else:
                    self.wfile.write('\t%s=%s<br>' % (field, item.value))

    def translate_path(self, path):
        """Override SimpleHTTPRequestHandler's translate_path to serve
        from an arbitrary docroot (self.server.docroot).

        The query string and fragment are dropped, the path is normalized,
        and '.'/'..' components and drive prefixes are discarded, so request
        paths cannot climb above the docroot via '..'.
        """
        # abandon query parameters
        path = urlparse.urlparse(path)[2]
        # normpath() drops a trailing slash, but SimpleHTTPRequestHandler
        # needs it to tell "/dir/" from "/file" (so "/file/" is a 404).
        trailing_slash = path.rstrip().endswith('/')
        path = posixpath.normpath(urllib.unquote(path))
        result = self.server.docroot
        for word in filter(None, path.split('/')):
            _, word = os.path.splitdrive(word)
            _, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                continue
            result = os.path.join(result, word)
        if trailing_slash:
            result += '/'
        logging.debug('Translated path: %s', result)
        return result

    def _fire_event(self):
        """Set, and unregister, the wait-URL event for self.path if any."""
        # pop() first: removing and reading in one step avoids a KeyError when
        # two requests race, and never deletes an entry that a woken waiter
        # re-registers for the same URL.
        entry = self.server._wait_urls.pop(self.path, None)
        if entry is None:
            logging.debug('URL %s not in watch list', self.path)
            return
        _, event = entry
        event.set()

    @_handle_http_errors
    def do_GET(self):
        """Run the registered handler for the path or serve from docroot.

        Handlers are called as handler(self, args), where args is the query
        string parsed by urlparse.parse_qs.
        """
        split_url = urlparse.urlsplit(self.path)
        path = split_url[2]
        # Strip off query parameters to ensure that the url path
        # matches any registered events.
        self.path = path
        args = urlparse.parse_qs(split_url[3])
        if path in self.server._url_handlers:
            self.server._url_handlers[path](self, args)
        else:
            SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
        self._fire_event()

    @_handle_http_errors
    def do_HEAD(self):
        """SimpleHTTPRequestHandler.do_HEAD with dropped connections logged."""
        SimpleHTTPServer.SimpleHTTPRequestHandler.do_HEAD(self)
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request on its own thread.

    ThreadingMixIn is listed first so that its process_request() takes
    precedence over HTTPServer's.
    """

    def __init__(self, server_address, HandlerClass):
        # Explicit base-class call: the Python 2 SocketServer classes are
        # old-style, so super() is not available here.
        HTTPServer.__init__(self, server_address,
                            RequestHandlerClass=HandlerClass)
class HTTPListener(object):
    """Runs an HTTP server with FormHandler on a background thread.

    Tests register wait URLs (add_wait_url) and URL handlers
    (add_url_handler), start serving with run(), read POSTed form fields
    back with get_form_entries(), and finish with stop().
    """

    # Point default docroot to a non-existent directory (instead of None) to
    # avoid exceptions when page content is served through handlers only.
    def __init__(self, port=0, docroot='/_', wait_urls=None, url_handlers=None):
        """Create a ThreadedHTTPServer listening on ('', port).

        Args:
            port: TCP port to listen on; 0 lets the OS choose one.
            docroot: Directory GET requests are served from.
            wait_urls: Optional dict of URL path -> (matchParams, Event).
                The caller's dict is used directly, not copied. A fresh dict
                is created when omitted, so instances never share one.
            url_handlers: Optional dict of URL path -> callable invoked as
                handler(request_handler, form_or_args). A fresh dict is
                created when omitted.
        """
        self._server = ThreadedHTTPServer(('', port), FormHandler)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def config_server(self, server, docroot, wait_urls, url_handlers):
        """Attach listener state to |server| and prepare its serving thread.

        FormHandler reads docroot, _wait_urls, _url_handlers and
        _form_entries from the server object. None for wait_urls or
        url_handlers means "start with an empty dict".
        """
        self._server = server
        # Stuff some convenient data fields into the server object.
        server.docroot = docroot
        server._wait_urls = {} if wait_urls is None else wait_urls
        server._url_handlers = {} if url_handlers is None else url_handlers
        server._form_entries = {}
        self._server_thread = threading.Thread(target=server.serve_forever)

    def add_wait_url(self, url='/', matchParams=None):
        """Register |url| and return a threading.Event set once it is hit.

        matchParams is stored next to the event; the request handler only
        unpacks the event from the stored pair.
        """
        event = threading.Event()
        if matchParams is None:
            matchParams = {}
        self._server._wait_urls[url] = (matchParams, event)
        return event

    def add_url_handler(self, url, handler_func):
        """Route requests for |url| to handler_func(request_handler, data)."""
        self._server._url_handlers[url] = handler_func

    def clear_form_entries(self):
        """Forget all recorded form fields (replaces the dict)."""
        self._server._form_entries = {}

    def get_form_entries(self):
        """Returns a dictionary of all field=values received by the server.
        """
        return self._server._form_entries

    def run(self):
        """Start serving on the background thread."""
        logging.debug('http server on %s:%d',
                      self._server.server_name, self._server.server_port)
        self._server_thread.start()

    def stop(self):
        """Stop the serving thread, if running, and close the server socket.

        Safe to call when run() was never called: shutdown() waits for
        serve_forever() to exit, so it would block forever, and join() on an
        unstarted thread raises RuntimeError.
        """
        if self._server_thread.is_alive():
            self._server.shutdown()
            self._server_thread.join()
        self._server.server_close()
class SecureHTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTPServer whose listening socket is wrapped with TLS.

    The socket is wrapped with ssl.wrap_socket(server_side=True) using the
    certificate at |cert_path| and the private key at |key_path|.
    """

    def __init__(self, server_address, HandlerClass, cert_path, key_path):
        # The socket is created and wrapped here rather than by
        # TCPServer.__init__, so only BaseServer's initializer runs below and
        # bind/activate are called explicitly.
        _socket = socket.socket(self.address_family, self.socket_type)
        try:
            # NOTE(review): ssl.wrap_socket and PROTOCOL_TLSv1 are deprecated
            # in newer Python releases (ssl.SSLContext replaces them); kept
            # unchanged here to avoid altering which clients can connect.
            self.socket = ssl.wrap_socket(_socket,
                                          server_side=True,
                                          ssl_version=ssl.PROTOCOL_TLSv1,
                                          certfile=cert_path,
                                          keyfile=key_path)
        except BaseException:
            # e.g. unreadable cert/key file: don't leak the raw socket.
            _socket.close()
            raise
        BaseServer.__init__(self, server_address, HandlerClass)
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            # e.g. address already in use: release the socket before failing,
            # as TCPServer.__init__ does.
            self.server_close()
            raise
class SecureHTTPRequestHandler(FormHandler):
    """FormHandler variant for connections accepted by SecureHTTPServer.

    Logging goes through FormHandler's log_error/log_message overrides,
    which are inherited rather than duplicated here.
    """

    def setup(self):
        # Replaces StreamRequestHandler.setup(): the read/write file objects
        # are built directly on the request socket with Python 2's
        # socket._fileobject, using the handler's rbufsize/wbufsize.
        # NOTE(review): unlike the base setup(), self.timeout and
        # disable_nagle_algorithm are not applied here.
        self.connection = self.request
        self.rfile = socket._fileobject(self.request, 'rb', self.rbufsize)
        self.wfile = socket._fileobject(self.request, 'wb', self.wbufsize)
class SecureHTTPListener(HTTPListener):
    """HTTPListener that serves HTTPS through SecureHTTPServer and
    SecureHTTPRequestHandler.
    """

    def __init__(self,
                 cert_path='/etc/login_trust_root.pem',
                 key_path='/etc/mock_server.key',
                 port=0,
                 docroot='/_',
                 wait_urls=None,
                 url_handlers=None):
        """Create a SecureHTTPServer listening on ('', port).

        cert_path and key_path are handed to SecureHTTPServer. Omitted
        wait_urls/url_handlers get a fresh dict per instance; a literal {}
        default would be shared by every listener.
        """
        if wait_urls is None:
            wait_urls = {}
        if url_handlers is None:
            url_handlers = {}
        self._server = SecureHTTPServer(('', port),
                                        SecureHTTPRequestHandler,
                                        cert_path,
                                        key_path)
        self.config_server(self._server, docroot, wait_urls, url_handlers)

    def getsockname(self):
        """Return the address the listening socket is bound to."""
        return self._server.socket.getsockname()