| ############################################################################## |
| # |
| # Copyright (c) 2001, 2002 Zope Foundation and Contributors. |
| # All Rights Reserved. |
| # |
| # This software is subject to the provisions of the Zope Public License, |
| # Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. |
| # THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED |
| # WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
| # WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS |
| # FOR A PARTICULAR PURPOSE. |
| # |
| ############################################################################## |
| |
| from collections import deque |
| import socket |
| import sys |
| import threading |
| import time |
| |
| from .buffers import ReadOnlyFileBasedBuffer |
| from .utilities import build_http_date, logger, queue_logger |
| |
# Headers whose CGI-style names are kept as-is in the WSGI environ instead
# of receiving the usual "HTTP_" prefix (per the CGI/WSGI conventions).
rename_headers = {name: name for name in ("CONTENT_LENGTH", "CONTENT_TYPE")}

# Hop-by-hop headers (RFC 2616 section 13.5.1); PEP 3333 forbids a WSGI
# application from setting any of these, so start_response rejects them.
hop_by_hop = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)
| |
| |
class ThreadedTaskDispatcher:
    """A Task Dispatcher that creates a thread for each task.

    Keeps a pool of worker threads pulling tasks from a shared FIFO queue;
    the pool size is adjusted with set_thread_count() and torn down with
    shutdown().
    """

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger  # destination for task-servicing errors
    queue_logger = queue_logger  # destination for queue-depth warnings

    def __init__(self):
        self.threads = set()  # thread numbers of live (or stopping) workers
        self.queue = deque()  # FIFO of tasks waiting for a worker
        # A single lock guards all dispatcher state; both condition
        # variables share it so queue/counter updates and waits stay
        # mutually consistent.
        self.lock = threading.Lock()
        self.queue_cv = threading.Condition(self.lock)  # work available / stop requested
        self.thread_exit_cv = threading.Condition(self.lock)  # a worker has exited

    def start_new_thread(self, target, thread_no):
        """Spawn a daemon thread running target(thread_no)."""
        t = threading.Thread(
            target=target, name="waitress-{}".format(thread_no), args=(thread_no,)
        )
        # Daemon so a forgotten shutdown() cannot hang interpreter exit.
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        """Worker loop: service queued tasks until asked to stop."""
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    # Asked to stop: update the bookkeeping and wake
                    # shutdown(), which waits on thread_exit_cv.
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            # Service the task outside the lock so other workers keep running.
            try:
                task.service()
            except BaseException:
                # A worker must survive any task failure; log and continue.
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Grow or shrink the worker pool to *count* threads."""
        with self.lock:
            threads = self.threads
            thread_no = 0
            # Threads already flagged to stop are still in self.threads,
            # so exclude them from the running total.
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    # Find the lowest unused thread number.
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, thread_no)
                # A new thread counts as active until it first goes idle.
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
            # Wake all idle workers: either there may be surplus workers
            # that must observe stop_count, or newly started ones may race.
            self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue *task* for servicing; warn when workers are saturated."""
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                # More queued tasks than idle workers: requests are backing up.
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        """Stop all workers, waiting up to *timeout* seconds.

        When cancel_pending is true, any still-queued tasks are cancelled
        and True is returned; otherwise returns False.
        """
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        expiration = time.time() + timeout
        with self.lock:
            while threads:
                if time.time() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                # Poll with a short wait so the timeout stays responsive.
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if len(queue) > 0:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                # Wake anything still blocked on the queue condition.
                self.queue_cv.notify_all()
                return True
        return False
| |
| |
class Task:
    """Base class for one HTTP response being produced on a channel.

    Subclasses supply execute(); this class handles response-header
    generation, Content-Length / chunked bookkeeping, and writing body
    bytes to the channel.
    """

    close_on_finish = False  # close the connection after this response?
    status = "200 OK"  # status line text: code + reason phrase
    wrote_header = False  # have the response headers been sent yet?
    start_time = 0  # time.time() when servicing began (for Date header)
    content_length = None  # declared body length, or None when unknown
    content_bytes_written = 0  # body bytes sent so far
    logged_write_excess = False  # warned once about writes past content_length?
    logged_write_no_body = False  # warned once about a body on a body-less status?
    complete = False  # has start_response (or equivalent) finished?
    chunked_response = False  # is the body using chunked transfer coding?
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        version = request.version
        if version not in ("1.0", "1.1"):
            # fall back to a version we support.
            version = "1.0"
        self.version = version

    def service(self):
        """Run the task; on socket errors, arrange to close the connection."""
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if self.channel.adj.log_socket_errors:
                raise

    @property
    def has_body(self):
        # Responses with a 1xx, 204 or 304 status must not carry a
        # message body (RFC 7230).
        return not (
            self.status.startswith("1")
            or self.status.startswith("204")
            or self.status.startswith("304")
        )

    def build_response_header(self):
        """Return the full response-header block as latin-1 bytes.

        Normalizes header capitalization, decides between Content-Length
        and chunked transfer coding, adds Server/Date defaults, and
        determines whether the connection must close after the response
        (setting self.close_on_finish as a side effect).
        """
        version = self.version
        # Figure out whether the connection should be closed.
        connection = self.request.headers.get("CONNECTION", "").lower()
        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for (headername, headerval) in self.response_headers:
            headername = "-".join([x.capitalize() for x in headername.split("-")])

            if headername == "Content-Length":
                if self.has_body:
                    content_length_header = headerval
                else:
                    # Drop Content-Length entirely on body-less statuses.
                    continue  # pragma: no cover

            if headername == "Date":
                date_header = headerval

            if headername == "Server":
                server_header = headerval

            if headername == "Connection":
                connection_close_header = headerval.lower()
            # replace with properly capitalized version
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def close_on_finish():
            # NB: this local helper intentionally shadows the instance
            # attribute of the same name within this method; it appends a
            # "Connection: close" header (unless the app already set one)
            # and flags the channel to close after the response.
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            if connection == "keep-alive":
                # A 1.0 connection can only stay alive with a known length.
                if not content_length_header:
                    close_on_finish()
                else:
                    response_headers.append(("Connection", "Keep-Alive"))
            else:
                close_on_finish()

        elif version == "1.1":
            if connection == "close":
                close_on_finish()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.

                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    close_on_finish()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if not server_header:
            if ident:
                response_headers.append(("Server", ident))
            else:
                response_headers.append(("Via", ident or "waitress"))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        first_line = "HTTP/%s %s" % (self.version, self.status)
        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
        # rely on stable sort to keep relative position of same-named headers
        next_lines = [
            "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
        ]
        lines = [first_line] + next_lines
        res = "%s\r\n\r\n" % "\r\n".join(lines)

        return res.encode("latin-1")

    def remove_content_length_header(self):
        """Strip any Content-Length header from the pending response headers."""
        response_headers = []

        for header_name, header_value in self.response_headers:
            if header_name.lower() == "content-length":
                continue  # pragma: nocover
            response_headers.append((header_name, header_value))

        self.response_headers = response_headers

    def start(self):
        # Record when servicing began; used for the Date header.
        self.start_time = time.time()

    def finish(self):
        """Flush headers if nothing was written; terminate chunked bodies."""
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Queue body *data* on the channel, sending headers first if needed.

        Applies chunked framing when chunked_response is set, truncates
        writes past a declared Content-Length (warning once), and drops
        the body entirely for statuses that must not carry one.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")
        channel = self.channel
        if not self.wrote_header:
            rh = self.build_response_header()
            channel.write_soon(rh)
            self.wrote_header = True

        if data and self.has_body:
            towrite = data
            cl = self.content_length
            if self.chunked_response:
                # use chunked encoding response
                towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n"
                towrite += data + b"\r\n"
            elif cl is not None:
                # Never send more than Content-Length promised.
                towrite = data[: cl - self.content_bytes_written]
                self.content_bytes_written += len(towrite)
                if towrite != data and not self.logged_write_excess:
                    self.logger.warning(
                        "application-written content exceeded the number of "
                        "bytes specified by Content-Length header (%s)" % cl
                    )
                    self.logged_write_excess = True
            if towrite:
                channel.write_soon(towrite)
        elif data:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
| |
| |
class ErrorTask(Task):
    """A task that renders the parse/protocol error attached to the request."""

    # start_response never runs for an error response, so the task is
    # considered complete from the outset (allowing write() immediately).
    complete = True

    def execute(self):
        """Build and write an error response from ``self.request.error``."""
        status, headers, body = self.request.error.to_response()
        self.status = status
        # We need to explicitly tell the remote client we are closing the
        # connection, because self.close_on_finish is set, and we are going to
        # slam the door in the clients face.
        all_headers = list(headers) + [("Connection", "close")]
        self.response_headers.extend(all_headers)
        self.close_on_finish = True
        self.content_length = len(body)
        self.write(body.encode("latin-1"))
| |
| |
class WSGITask(Task):
    """A WSGI task produces a response from a WSGI application."""

    environ = None  # cached WSGI environ; built lazily by get_environment()

    def execute(self):
        """Call the WSGI application and stream its response to the channel."""
        environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            # The PEP 3333 start_response callable handed to the application.
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )
            if exc_info:
                try:
                    if self.wrote_header:
                        # higher levels will catch and handle raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
                        # 3. "handler_thread" method in task.py
                        raise exc_info[1]
                    else:
                        # As per WSGI spec existing headers must be cleared
                        self.response_headers = []
                finally:
                    # Break the reference cycle per PEP 3333.
                    exc_info = None

            self.complete = True

            if not status.__class__ is str:
                raise AssertionError("status %s is not a string" % status)
            if "\n" in status or "\r" in status:
                raise ValueError(
                    "carriage return/line feed character present in status"
                )

            self.status = status

            # Prepare the headers for output
            for k, v in headers:
                if not k.__class__ is str:
                    raise AssertionError(
                        "Header name %r is not a string in %r" % (k, (k, v))
                    )
                if not v.__class__ is str:
                    raise AssertionError(
                        "Header value %r is not a string in %r" % (v, (k, v))
                    )

                # Reject header injection (CR/LF smuggling) outright.
                if "\n" in v or "\r" in v:
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if "\n" in k or "\r" in k:
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                kl = k.lower()
                if kl == "content-length":
                    self.content_length = int(v)
                elif kl in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % k
                    )

            self.response_headers.extend(headers)

            # Return a method used to write the response data.
            return self.write

        # Call the application to handle the request and write a response
        app_iter = self.channel.server.application(environ, start_response)

        can_close_app_iter = True
        try:
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                # Fast path: the app returned our wsgi.file_wrapper; hand
                # the buffer directly to the channel.
                cl = self.content_length
                size = app_iter.prepare(cl)
                if size:
                    if cl != size:
                        if cl is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # if the write_soon below succeeds then the channel will
                    # take over closing the underlying file via the channel's
                    # _flush_some or handle_close so we intentionally avoid
                    # calling close in the finally block
                    self.channel.write_soon(app_iter)
                    can_close_app_iter = False
                    return

            first_chunk_len = None
            for chunk in app_iter:
                if first_chunk_len is None:
                    first_chunk_len = len(chunk)
                    # Set a Content-Length header if one is not supplied.
                    # start_response may not have been called until first
                    # iteration as per PEP, so we must reinterrogate
                    # self.content_length here
                    if self.content_length is None:
                        app_iter_len = None
                        if hasattr(app_iter, "__len__"):
                            app_iter_len = len(app_iter)
                        if app_iter_len == 1:
                            self.content_length = first_chunk_len
                # transmit headers only after first iteration of the iterable
                # that returns a non-empty bytestring (PEP 3333)
                if chunk:
                    self.write(chunk)

            cl = self.content_length
            if cl is not None:
                if self.content_bytes_written != cl:
                    # close the connection so the client isn't sitting around
                    # waiting for more data when there are too few bytes
                    # to service content-length
                    self.close_on_finish = True
                    if self.request.command != "HEAD":
                        self.logger.warning(
                            "application returned too few bytes (%s) "
                            "for specified Content-Length (%s) via app_iter"
                            % (self.content_bytes_written, cl),
                        )
        finally:
            # Honour the iterable's close() per PEP 3333, unless the channel
            # took ownership of a file-wrapper buffer above.
            if can_close_app_iter and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Returns a WSGI environment."""
        environ = self.environ
        if environ is not None:
            # Return the cached copy.
            return environ

        request = self.request
        path = request.path
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix

        if path.startswith("/"):
            # strip extra slashes at the beginning of a path that starts
            # with any number of slashes
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # if the path is the same as the url prefix, the SCRIPT_NAME
                # should be the url_prefix and PATH_INFO should be empty
                path = ""
            else:
                # if the path starts with the url prefix plus a slash,
                # the SCRIPT_NAME should be the url_prefix and PATH_INFO should
                # the value of path from the slash until its end
                url_prefix_with_trailing_slash = url_prefix + "/"
                if path.startswith(url_prefix_with_trailing_slash):
                    path = path[len(url_prefix) :]

        environ = {
            "REMOTE_ADDR": channel.addr[0],
            # Nah, we aren't actually going to look up the reverse DNS for
            # REMOTE_ADDR, but we will happily set this environment variable
            # for the WSGI application. Spec says we can just set this to
            # REMOTE_ADDR, so we do.
            "REMOTE_HOST": channel.addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            value = value.strip()
            mykey = rename_headers.get(key, None)
            if mykey is None:
                # All other request headers get the CGI-style HTTP_ prefix.
                mykey = "HTTP_" + key
            if mykey not in environ:
                # First occurrence wins; duplicates are not merged here.
                environ[mykey] = value

        # Insert a callable into the environment that allows the application to
        # check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected

        # cache the environ for this request
        self.environ = environ
        return environ