Allow tasks to notice if client disconnected

This inserts a callable `waitress.client_disconnected` into the
environment that allows the task to check if the client disconnected
while waiting for the response at strategic points in the execution,
allowing to cancel the operation.

It requires setting the new adjustment `channel_request_lookahead` to a
value larger than 0, which continues to read requests from a channel
even if a request is already being processed on that channel, up to the
given count, since a client disconnect is detected by reading from a
readable socket and receiving an empty result.
diff --git a/CHANGES.txt b/CHANGES.txt
index f4d1acc..894ff94 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,5 +1,19 @@
 2.0.0 (unreleased)
 ------------------
+- Allow tasks to notice if the client disconnected.
+
+  This inserts a callable `waitress.client_disconnected` into the environment
+  that allows the task to check if the client disconnected while waiting for
+  the response at strategic points in the execution and to cancel the
+  operation.
+
+  It requires setting the new adjustment `channel_request_lookahead` to a value
+  larger than 0, which continues to read requests from a channel even if a
+  request is already being processed on that channel, up to the given count,
+  since a client disconnect is detected by reading from a readable socket and
+  receiving an empty result.
+
+  See https://github.com/Pylons/waitress/pull/310
 
 - Drop Python 2.7 support
 
diff --git a/src/waitress/adjustments.py b/src/waitress/adjustments.py
index 45ac41b..42d2bc0 100644
--- a/src/waitress/adjustments.py
+++ b/src/waitress/adjustments.py
@@ -135,6 +135,7 @@
         ("unix_socket", str),
         ("unix_socket_perms", asoctal),
         ("sockets", as_socket_list),
+        ("channel_request_lookahead", int),
     )
 
     _param_map = dict(_params)
@@ -280,6 +281,13 @@
     # be used for e.g. socket activation
     sockets = []
 
+    # By setting this to a value larger than zero, each channel stays readable
+    # and continues to read requests from the client even if a request is still
+    # running, until the number of buffered requests exceeds this value.
+    # This allows detecting if a client closed the connection while its request
+    # is being processed.
+    channel_request_lookahead = 0
+
     def __init__(self, **kw):
 
         if "listen" in kw and ("host" in kw or "port" in kw):
diff --git a/src/waitress/channel.py b/src/waitress/channel.py
index 65bc87f..296a16a 100644
--- a/src/waitress/channel.py
+++ b/src/waitress/channel.py
@@ -40,11 +40,11 @@
     error_task_class = ErrorTask
     parser_class = HTTPRequestParser
 
-    request = None  # A request parser instance
+    # A request that has not been received yet completely is stored here
+    request = None
     last_activity = 0  # Time of last activity
     will_close = False  # set to True to close the socket.
     close_when_flushed = False  # set to True to close the socket when flushed
-    requests = ()  # currently pending requests
     sent_continue = False  # used as a latch after sending 100 continue
     total_outbufs_len = 0  # total bytes ready to send
     current_outbuf_count = 0  # total bytes written to current outbuf
@@ -60,8 +60,9 @@
         self.creation_time = self.last_activity = time.time()
         self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
 
-        # task_lock used to push/pop requests
-        self.task_lock = threading.Lock()
+        # requests_lock used to push/pop requests and modify the request that is
+        # currently being created
+        self.requests_lock = threading.Lock()
         # outbuf_lock used to access any outbuf (expected to use an RLock)
         self.outbuf_lock = threading.Condition()
 
@@ -69,6 +70,15 @@
 
         # Don't let wasyncore.dispatcher throttle self.addr on us.
         self.addr = addr
+        self.requests = []
+
+    def check_client_disconnected(self):
+        """
+        This method is inserted into the environment of any created task so it
+        may occasionally check if the client has disconnected and interrupt
+        execution.
+        """
+        return not self.connected
 
     def writable(self):
         # if there's data in the out buffer or we've been instructed to close
@@ -125,18 +135,18 @@
             self.handle_close()
 
     def readable(self):
-        # We might want to create a new task.  We can only do this if:
+        # We might want to read more requests. We can only do this if:
         # 1. We're not already about to close the connection.
         # 2. We're not waiting to flush remaining data before closing the
         #    connection
-        # 3. There's no already currently running task(s).
+        # 3. There are not too many tasks already queued
         # 4. There's no data in the output buffer that needs to be sent
         #    before we potentially create a new task.
 
         return not (
             self.will_close
             or self.close_when_flushed
-            or self.requests
+            or len(self.requests) > self.adj.channel_request_lookahead
             or self.total_outbufs_len
         )
 
@@ -153,57 +163,69 @@
         if data:
             self.last_activity = time.time()
             self.received(data)
+        else:
+            # Client disconnected.
+            self.connected = False
+
+    def send_continue(self):
+        """
+        Send a 100-Continue header to the client. This is either called from
+        receive (if no requests are running and the client expects it) or at
+        the end of service (if no more requests are queued and a request has
+        been read partially that expects it).
+        """
+        self.request.expect_continue = False
+        outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
+        num_bytes = len(outbuf_payload)
+        with self.outbuf_lock:
+            self.outbufs[-1].append(outbuf_payload)
+            self.current_outbuf_count += num_bytes
+            self.total_outbufs_len += num_bytes
+            self.sent_continue = True
+            self._flush_some()
+        self.request.completed = False
 
     def received(self, data):
         """
         Receives input asynchronously and assigns one or more requests to the
         channel.
         """
-        # Preconditions: there's no task(s) already running
-        request = self.request
-        requests = []
-
         if not data:
             return False
 
-        while data:
-            if request is None:
-                request = self.parser_class(self.adj)
-            n = request.received(data)
+        with self.requests_lock:
+            while data:
+                if self.request is None:
+                    self.request = self.parser_class(self.adj)
+                n = self.request.received(data)
 
-            if request.expect_continue and request.headers_finished:
-                # guaranteed by parser to be a 1.1 request
-                request.expect_continue = False
+                # if there are requests queued, we can not send the continue
+                # header yet since the responses need to be kept in order
+                if (
+                    self.request.expect_continue
+                    and self.request.headers_finished
+                    and not self.requests
+                    and not self.sent_continue
+                ):
+                    self.send_continue()
 
-                if not self.sent_continue:
-                    # there's no current task, so we don't need to try to
-                    # lock the outbuf to append to it.
-                    outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
-                    num_bytes = len(outbuf_payload)
-                    self.outbufs[-1].append(outbuf_payload)
-                    self.current_outbuf_count += num_bytes
-                    self.total_outbufs_len += num_bytes
-                    self.sent_continue = True
-                    self._flush_some()
-                    request.completed = False
+                if self.request.completed:
+                    # The request (with the body) is ready to use.
+                    self.sent_continue = False
 
-            if request.completed:
-                # The request (with the body) is ready to use.
-                self.request = None
+                    if not self.request.empty:
+                        self.requests.append(self.request)
+                        if len(self.requests) == 1:
+                            # self.requests was empty before so the main thread
+                            # is in charge of starting the task. Otherwise,
+                            # service() will add a new task after each request
+                            # has been processed
+                            self.server.add_task(self)
+                    self.request = None
 
-                if not request.empty:
-                    requests.append(request)
-                request = None
-            else:
-                self.request = request
-
-            if n >= len(data):
-                break
-            data = data[n:]
-
-        if requests:
-            self.requests = requests
-            self.server.add_task(self)
+                if n >= len(data):
+                    break
+                data = data[n:]
 
         return True
 
@@ -360,88 +382,101 @@
                     self.outbuf_lock.wait()
 
     def service(self):
-        """Execute all pending requests """
-        with self.task_lock:
-            while self.requests:
-                request = self.requests[0]
+        """Execute one request. If there are more, we add another task to the
+        server at the end."""
 
-                if request.error:
-                    task = self.error_task_class(self, request)
+        request = self.requests[0]
+
+        if request.error:
+            task = self.error_task_class(self, request)
+        else:
+            task = self.task_class(self, request)
+
+        try:
+            if self.connected:
+                task.service()
+            else:
+                task.close_on_finish = True
+        except ClientDisconnected:
+            self.logger.info("Client disconnected while serving %s" % task.request.path)
+            task.close_on_finish = True
+        except Exception:
+            self.logger.exception("Exception while serving %s" % task.request.path)
+
+            if not task.wrote_header:
+                if self.adj.expose_tracebacks:
+                    body = traceback.format_exc()
                 else:
-                    task = self.task_class(self, request)
+                    body = "The server encountered an unexpected internal server error"
+                req_version = request.version
+                req_headers = request.headers
+                err_request = self.parser_class(self.adj)
+                err_request.error = InternalServerError(body)
+                # copy some original request attributes to fulfill
+                # HTTP 1.1 requirements
+                err_request.version = req_version
                 try:
-                    task.service()
+                    err_request.headers["CONNECTION"] = req_headers["CONNECTION"]
+                except KeyError:
+                    pass
+                task = self.error_task_class(self, err_request)
+                try:
+                    task.service()  # must not fail
                 except ClientDisconnected:
-                    self.logger.info(
-                        "Client disconnected while serving %s" % task.request.path
-                    )
                     task.close_on_finish = True
-                except Exception:
-                    self.logger.exception(
-                        "Exception while serving %s" % task.request.path
-                    )
+            else:
+                task.close_on_finish = True
 
-                    if not task.wrote_header:
-                        if self.adj.expose_tracebacks:
-                            body = traceback.format_exc()
-                        else:
-                            body = (
-                                "The server encountered an unexpected "
-                                "internal server error"
-                            )
-                        req_version = request.version
-                        req_headers = request.headers
-                        request = self.parser_class(self.adj)
-                        request.error = InternalServerError(body)
-                        # copy some original request attributes to fulfill
-                        # HTTP 1.1 requirements
-                        request.version = req_version
-                        try:
-                            request.headers["CONNECTION"] = req_headers["CONNECTION"]
-                        except KeyError:
-                            pass
-                        task = self.error_task_class(self, request)
-                        try:
-                            task.service()  # must not fail
-                        except ClientDisconnected:
-                            task.close_on_finish = True
-                    else:
-                        task.close_on_finish = True
-                # we cannot allow self.requests to drop to empty til
-                # here; otherwise the mainloop gets confused
+        if task.close_on_finish:
+            with self.requests_lock:
+                self.close_when_flushed = True
 
-                if task.close_on_finish:
-                    self.close_when_flushed = True
-
-                    for request in self.requests:
-                        request.close()
-                    self.requests = []
-                else:
-                    # before processing a new request, ensure there is not too
-                    # much data in the outbufs waiting to be flushed
-                    # NB: currently readable() returns False while we are
-                    # flushing data so we know no new requests will come in
-                    # that we need to account for, otherwise it'd be better
-                    # to do this check at the start of the request instead of
-                    # at the end to account for consecutive service() calls
-
-                    if len(self.requests) > 1:
-                        self._flush_outbufs_below_high_watermark()
-
-                    # this is a little hacky but basically it's forcing the
-                    # next request to create a new outbuf to avoid sharing
-                    # outbufs across requests which can cause outbufs to
-                    # not be deallocated regularly when a connection is open
-                    # for a long time
-
-                    if self.current_outbuf_count > 0:
-                        self.current_outbuf_count = self.adj.outbuf_high_watermark
-
-                    request = self.requests.pop(0)
+                for request in self.requests:
                     request.close()
+                self.requests = []
+        else:
+            # before processing a new request, ensure there is not too
+            # much data in the outbufs waiting to be flushed
+            # NB: currently readable() returns False while we are
+            # flushing data so we know no new requests will come in
+            # that we need to account for, otherwise it'd be better
+            # to do this check at the start of the request instead of
+            # at the end to account for consecutive service() calls
+
+            if len(self.requests) > 1:
+                self._flush_outbufs_below_high_watermark()
+
+            # this is a little hacky but basically it's forcing the
+            # next request to create a new outbuf to avoid sharing
+            # outbufs across requests which can cause outbufs to
+            # not be deallocated regularly when a connection is open
+            # for a long time
+
+            if self.current_outbuf_count > 0:
+                self.current_outbuf_count = self.adj.outbuf_high_watermark
+
+            request.close()
+
+            # Add new task to process the next request
+            with self.requests_lock:
+                self.requests.pop(0)
+                if self.connected and self.requests:
+                    self.server.add_task(self)
+                elif (
+                    self.connected
+                    and self.request is not None
+                    and self.request.expect_continue
+                    and self.request.headers_finished
+                    and not self.sent_continue
+                ):
+                    # A request waits for a signal to continue, but we could
+                    # not send it until now because requests were being
+                    # processed and the output needs to be kept in order
+                    self.send_continue()
 
         if self.connected:
             self.server.pull_trigger()
+
         self.last_activity = time.time()
 
     def cancel(self):
diff --git a/src/waitress/runner.py b/src/waitress/runner.py
index 4fb3e6b..c23ca0e 100644
--- a/src/waitress/runner.py
+++ b/src/waitress/runner.py
@@ -169,6 +169,12 @@
         The use_poll argument passed to ``asyncore.loop()``. Helps overcome
         open file descriptors limit. Default is False.
 
+    --channel-request-lookahead=INT
+        Allows channels to stay readable and buffer more requests up to the
+        given maximum even if a request is already being processed. This allows
+        detecting if a client closed the connection while its request is being
+        processed. Default is 0.
+
 """
 
 RUNNER_PATTERN = re.compile(
diff --git a/src/waitress/task.py b/src/waitress/task.py
index 3a7cf17..2ac8f4c 100644
--- a/src/waitress/task.py
+++ b/src/waitress/task.py
@@ -560,6 +560,11 @@
             if mykey not in environ:
                 environ[mykey] = value
 
+        # Insert a callable into the environment that allows the application to
+        # check if the client disconnected. Only works with
+        # channel_request_lookahead larger than 0.
+        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected
+
         # cache the environ for this request
         self.environ = environ
         return environ
diff --git a/tests/test_channel.py b/tests/test_channel.py
index df3d450..d86dbbe 100644
--- a/tests/test_channel.py
+++ b/tests/test_channel.py
@@ -1,6 +1,8 @@
 import io
 import unittest
 
+import pytest
+
 
 class TestHTTPChannel(unittest.TestCase):
     def _makeOne(self, sock, addr, adj, map=None):
@@ -173,7 +175,7 @@
 
     def test_readable_with_requests(self):
         inst, sock, map = self._makeOneWithMap()
-        inst.requests = True
+        inst.requests = [True]
         self.assertEqual(inst.readable(), False)
 
     def test_handle_read_no_error(self):
@@ -189,8 +191,6 @@
         self.assertEqual(L, [b"abc"])
 
     def test_handle_read_error(self):
-        import socket
-
         inst, sock, map = self._makeOneWithMap()
         inst.will_close = False
 
@@ -439,7 +439,7 @@
         preq.completed = False
         preq.empty = True
         inst.received(b"GET / HTTP/1.1\r\n\r\n")
-        self.assertEqual(inst.requests, ())
+        self.assertEqual(inst.requests, [])
         self.assertEqual(inst.server.tasks, [])
 
     def test_received_preq_completed_empty(self):
@@ -525,14 +525,6 @@
         self.assertEqual(inst.sent_continue, True)
         self.assertEqual(preq.completed, False)
 
-    def test_service_no_requests(self):
-        inst, sock, map = self._makeOneWithMap()
-        inst.requests = []
-        inst.service()
-        self.assertEqual(inst.requests, [])
-        self.assertTrue(inst.server.trigger_pulled)
-        self.assertTrue(inst.last_activity)
-
     def test_service_with_one_request(self):
         inst, sock, map = self._makeOneWithMap()
         request = DummyRequest()
@@ -561,6 +553,7 @@
         inst.task_class = DummyTaskClass()
         inst.requests = [request1, request2]
         inst.service()
+        inst.service()
         self.assertEqual(inst.requests, [])
         self.assertTrue(request1.serviced)
         self.assertTrue(request2.serviced)
@@ -705,6 +698,137 @@
         self.assertEqual(inst.requests, [])
 
 
+class TestHTTPChannelLookahead(TestHTTPChannel):
+    def app_check_disconnect(self, environ, start_response):
+        """
+        Application that checks for client disconnection every
+        second for up to two seconds.
+        """
+        import time
+
+        if hasattr(self, "app_started"):
+            self.app_started.set()
+
+        try:
+            request_body_size = int(environ.get("CONTENT_LENGTH", 0))
+        except ValueError:
+            request_body_size = 0
+        self.request_body = environ["wsgi.input"].read(request_body_size)
+
+        self.disconnect_detected = False
+        check = environ["waitress.client_disconnected"]
+        if environ["PATH_INFO"] == "/sleep":
+            for i in range(3):
+                if i != 0:
+                    time.sleep(1)
+                if check():
+                    self.disconnect_detected = True
+                    break
+
+        body = b"finished"
+        cl = str(len(body))
+        start_response(
+            "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]
+        )
+        return [body]
+
+    def _make_app_with_lookahead(self):
+        """
+        Setup a channel with lookahead and store it and the socket in self
+        """
+        adj = DummyAdjustments()
+        adj.channel_request_lookahead = 5
+        channel, sock, map = self._makeOneWithMap(adj=adj)
+        channel.server.application = self.app_check_disconnect
+
+        self.channel = channel
+        self.sock = sock
+
+    def _send(self, *lines):
+        """
+        Send lines through the socket with correct line endings
+        """
+        self.sock.send("".join(line + "\r\n" for line in lines).encode("ascii"))
+
+    def test_client_disconnect(self, close_before_start=False):
+        """Disconnect the socket after starting the task."""
+        import threading
+
+        self._make_app_with_lookahead()
+        self._send(
+            "GET /sleep HTTP/1.1",
+            "Host: localhost:8080",
+            "",
+        )
+        self.assertTrue(self.channel.readable())
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.server.tasks), 1)
+        self.app_started = threading.Event()
+        self.disconnect_detected = False
+        thread = threading.Thread(target=self.channel.server.tasks[0].service)
+
+        if not close_before_start:
+            thread.start()
+            self.assertTrue(self.app_started.wait(timeout=5))
+
+        # Close the socket, check that the channel is still readable due to the
+        # lookahead and read it, which marks the channel as closed.
+        self.sock.close()
+        self.assertTrue(self.channel.readable())
+        self.channel.handle_read()
+
+        if close_before_start:
+            thread.start()
+
+        thread.join()
+
+        if close_before_start:
+            self.assertFalse(self.app_started.is_set())
+        else:
+            self.assertTrue(self.disconnect_detected)
+
+    def test_client_disconnect_immediate(self):
+        """
+        The same test, but this time we close the socket even before processing
+        started. The app should not be executed.
+        """
+        self.test_client_disconnect(close_before_start=True)
+
+    def test_lookahead_continue(self):
+        """
+        Send two requests to a channel with lookahead and use an
+        expect-continue on the second one, making sure the responses still come
+        in the right order.
+        """
+        self._make_app_with_lookahead()
+        self._send(
+            "POST / HTTP/1.1",
+            "Host: localhost:8080",
+            "Content-Length: 1",
+            "",
+            "x",
+            "POST / HTTP/1.1",
+            "Host: localhost:8080",
+            "Content-Length: 1",
+            "Expect: 100-continue",
+            "",
+        )
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.requests), 1)
+        self.channel.server.tasks[0].service()
+        data = self.sock.recv(256).decode("ascii")
+        self.assertTrue(data.endswith("HTTP/1.1 100 Continue\r\n\r\n"))
+
+        self.sock.send(b"x")
+        self.channel.handle_read()
+        self.assertEqual(len(self.channel.requests), 1)
+        self.channel.server.tasks[0].service()
+        self.channel._flush_some()
+        data = self.sock.recv(256).decode("ascii")
+        self.assertEqual(data.split("\r\n")[-1], "finished")
+        self.assertEqual(self.request_body, b"x")
+
+
 class DummySock:
     blocking = False
     closed = False
@@ -731,6 +855,11 @@
         self.sent += data
         return len(data)
 
+    def recv(self, buffer_size):
+        result = self.sent[:buffer_size]
+        self.sent = self.sent[buffer_size:]
+        return result
+
 
 class DummyLock:
     notified = False
@@ -796,11 +925,16 @@
     expose_tracebacks = True
     ident = "waitress"
     max_request_header_size = 10000
+    url_prefix = ""
+    channel_request_lookahead = 0
+    max_request_body_size = 1048576
 
 
 class DummyServer:
     trigger_pulled = False
     adj = DummyAdjustments()
+    effective_port = 8080
+    server_name = ""
 
     def __init__(self):
         self.tasks = []
diff --git a/tests/test_task.py b/tests/test_task.py
index de800fb..ea71e02 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -802,6 +802,7 @@
                 "SERVER_PORT",
                 "SERVER_PROTOCOL",
                 "SERVER_SOFTWARE",
+                "waitress.client_disconnected",
                 "wsgi.errors",
                 "wsgi.file_wrapper",
                 "wsgi.input",
@@ -958,6 +959,10 @@
     creation_time = 0
     addr = ("127.0.0.1", 39830)
 
+    def check_client_disconnected(self):
+        # For now, until we have tests handling this feature
+        return False
+
     def __init__(self, server=None):
         if server is None:
             server = DummyServer()