Merge pull request #10022 from kpayson64/resource_exauhsted

Add maximum_concurrent_rpcs argument to server
diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py
index a4481b2..4960df3 100644
--- a/src/python/grpcio/grpc/__init__.py
+++ b/src/python/grpcio/grpc/__init__.py
@@ -1273,7 +1273,10 @@
                             credentials._credentials)
 
 
-def server(thread_pool, handlers=None, options=None):
+def server(thread_pool,
+           handlers=None,
+           options=None,
+           maximum_concurrent_rpcs=None):
     """Creates a Server with which RPCs can be serviced.
 
   Args:
@@ -1286,13 +1289,17 @@
       returned Server is started.
     options: A sequence of string-value pairs according to which to configure
       the created server.
+    maximum_concurrent_rpcs: The maximum number of concurrent RPCs this server
+      will service before rejecting further RPCs with status
+      RESOURCE_EXHAUSTED, or None to indicate no limit.
 
   Returns:
     A Server with which RPCs can be serviced.
   """
     from grpc import _server  # pylint: disable=cyclic-import
     return _server.Server(thread_pool, () if handlers is None else handlers, ()
-                          if options is None else options)
+                          if options is None else options,
+                          maximum_concurrent_rpcs)
 
 
 ###################################  __all__  #################################
diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py
index 84e096d..47838c2 100644
--- a/src/python/grpcio/grpc/_server.py
+++ b/src/python/grpcio/grpc/_server.py
@@ -504,37 +504,37 @@
 def _handle_unary_unary(rpc_event, state, method_handler, thread_pool):
     unary_request = _unary_request(rpc_event, state,
                                    method_handler.request_deserializer)
-    thread_pool.submit(_unary_response_in_pool, rpc_event, state,
-                       method_handler.unary_unary, unary_request,
-                       method_handler.request_deserializer,
-                       method_handler.response_serializer)
+    return thread_pool.submit(_unary_response_in_pool, rpc_event, state,
+                              method_handler.unary_unary, unary_request,
+                              method_handler.request_deserializer,
+                              method_handler.response_serializer)  # Future is tracked by _serve.
 
 
 def _handle_unary_stream(rpc_event, state, method_handler, thread_pool):
     unary_request = _unary_request(rpc_event, state,
                                    method_handler.request_deserializer)
-    thread_pool.submit(_stream_response_in_pool, rpc_event, state,
-                       method_handler.unary_stream, unary_request,
-                       method_handler.request_deserializer,
-                       method_handler.response_serializer)
+    return thread_pool.submit(_stream_response_in_pool, rpc_event, state,
+                              method_handler.unary_stream, unary_request,
+                              method_handler.request_deserializer,
+                              method_handler.response_serializer)  # Future is tracked by _serve.
 
 
 def _handle_stream_unary(rpc_event, state, method_handler, thread_pool):
     request_iterator = _RequestIterator(state, rpc_event.operation_call,
                                         method_handler.request_deserializer)
-    thread_pool.submit(_unary_response_in_pool, rpc_event, state,
-                       method_handler.stream_unary, lambda: request_iterator,
-                       method_handler.request_deserializer,
-                       method_handler.response_serializer)
+    return thread_pool.submit(  # Future is tracked by _serve.
+        _unary_response_in_pool, rpc_event, state, method_handler.stream_unary,
+        lambda: request_iterator, method_handler.request_deserializer,
+        method_handler.response_serializer)
 
 
 def _handle_stream_stream(rpc_event, state, method_handler, thread_pool):
     request_iterator = _RequestIterator(state, rpc_event.operation_call,
                                         method_handler.request_deserializer)
-    thread_pool.submit(_stream_response_in_pool, rpc_event, state,
-                       method_handler.stream_stream, lambda: request_iterator,
-                       method_handler.request_deserializer,
-                       method_handler.response_serializer)
+    return thread_pool.submit(  # Future is tracked by _serve.
+        _stream_response_in_pool, rpc_event, state,
+        method_handler.stream_stream, lambda: request_iterator,
+        method_handler.request_deserializer, method_handler.response_serializer)
 
 
 def _find_method_handler(rpc_event, generic_handlers):
@@ -549,13 +549,12 @@
         return None
 
 
-def _handle_unrecognized_method(rpc_event):
+def _reject_rpc(rpc_event, status, details):  # Ends the RPC with the given status/details.
     operations = (cygrpc.operation_send_initial_metadata(_common.EMPTY_METADATA,
                                                          _EMPTY_FLAGS),
                   cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
                   cygrpc.operation_send_status_from_server(
-                      _common.EMPTY_METADATA, cygrpc.StatusCode.unimplemented,
-                      b'Method not found!', _EMPTY_FLAGS),)
+                      _common.EMPTY_METADATA, status, details, _EMPTY_FLAGS),)
     rpc_state = _RPCState()
     rpc_event.operation_call.start_server_batch(
         operations, lambda ignored_event: (rpc_state, (),))
@@ -572,33 +571,37 @@
         state.due.add(_RECEIVE_CLOSE_ON_SERVER_TOKEN)
         if method_handler.request_streaming:
             if method_handler.response_streaming:
-                _handle_stream_stream(rpc_event, state, method_handler,
-                                      thread_pool)
+                return state, _handle_stream_stream(rpc_event, state,
+                                                    method_handler, thread_pool)
             else:
-                _handle_stream_unary(rpc_event, state, method_handler,
-                                     thread_pool)
+                return state, _handle_stream_unary(rpc_event, state,
+                                                   method_handler, thread_pool)
         else:
             if method_handler.response_streaming:
-                _handle_unary_stream(rpc_event, state, method_handler,
-                                     thread_pool)
+                return state, _handle_unary_stream(rpc_event, state,
+                                                   method_handler, thread_pool)
             else:
-                _handle_unary_unary(rpc_event, state, method_handler,
-                                    thread_pool)
-        return state
+                return state, _handle_unary_unary(rpc_event, state,
+                                                  method_handler, thread_pool)
 
 
-def _handle_call(rpc_event, generic_handlers, thread_pool):
+def _handle_call(rpc_event, generic_handlers, thread_pool,
+                 concurrency_exceeded):  # True -> reject with RESOURCE_EXHAUSTED.
     if not rpc_event.success:
-        return None
+        return None, None  # Always returns (rpc_state, rpc_future).
     if rpc_event.request_call_details.method is not None:
         method_handler = _find_method_handler(rpc_event, generic_handlers)
         if method_handler is None:
-            return _handle_unrecognized_method(rpc_event)
+            return _reject_rpc(rpc_event, cygrpc.StatusCode.unimplemented,
+                               b'Method not found!'), None
+        elif concurrency_exceeded:  # Checked only after the method lookup.
+            return _reject_rpc(rpc_event, cygrpc.StatusCode.resource_exhausted,
+                               b'Concurrent RPC limit exceeded!'), None  # No future: not counted.
         else:
             return _handle_with_method_handler(rpc_event, method_handler,
                                                thread_pool)
     else:
-        return None
+        return None, None
 
 
 @enum.unique
@@ -610,7 +613,8 @@
 
 class _ServerState(object):
 
-    def __init__(self, completion_queue, server, generic_handlers, thread_pool):
+    def __init__(self, completion_queue, server, generic_handlers, thread_pool,
+                 maximum_concurrent_rpcs):  # None means no limit.
         self.lock = threading.Lock()
         self.completion_queue = completion_queue
         self.server = server
@@ -618,6 +622,8 @@
         self.thread_pool = thread_pool
         self.stage = _ServerStage.STOPPED
         self.shutdown_events = None
+        self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
+        self.active_rpc_count = 0  # RPCs with a pending pool future; guarded by self.lock.
 
         # TODO(https://github.com/grpc/grpc/issues/6597): eliminate these fields.
         self.rpc_states = set()
@@ -657,6 +663,11 @@
         return False
 
 
+def _on_call_completed(state):  # Done-callback of each RPC future (see _serve).
+    with state.lock:  # NOTE(review): deadlocks if the caller already holds state.lock.
+        state.active_rpc_count -= 1
+
+
 def _serve(state):
     while True:
         event = state.completion_queue.poll()
@@ -668,10 +679,18 @@
         elif event.tag is _REQUEST_CALL_TAG:
             with state.lock:
                 state.due.remove(_REQUEST_CALL_TAG)
-                rpc_state = _handle_call(event, state.generic_handlers,
-                                         state.thread_pool)
+                concurrency_exceeded = (  # Decided under state.lock.
+                    state.maximum_concurrent_rpcs is not None and
+                    state.active_rpc_count >= state.maximum_concurrent_rpcs)
+                rpc_state, rpc_future = _handle_call(
+                    event, state.generic_handlers, state.thread_pool,
+                    concurrency_exceeded)
                 if rpc_state is not None:
                     state.rpc_states.add(rpc_state)
+                if rpc_future is not None:
+                    state.active_rpc_count += 1  # NOTE(review): if rpc_future is already done,
+                    rpc_future.add_done_callback(  # the callback runs inline while
+                        lambda unused_future: _on_call_completed(state))  # state.lock is held: deadlock.
                 if state.stage is _ServerStage.STARTED:
                     _request_call(state)
                 elif _stop_serving(state):
@@ -749,12 +768,13 @@
 
 class Server(grpc.Server):
 
-    def __init__(self, thread_pool, generic_handlers, options):
+    def __init__(self, thread_pool, generic_handlers, options,
+                 maximum_concurrent_rpcs):  # None disables the limit.
         completion_queue = cygrpc.CompletionQueue()
         server = cygrpc.Server(_common.channel_args(options))
         server.register_completion_queue(completion_queue)
         self._state = _ServerState(completion_queue, server, generic_handlers,
-                                   thread_pool)
+                                   thread_pool, maximum_concurrent_rpcs)
 
     def add_generic_rpc_handlers(self, generic_rpc_handlers):
         _add_generic_handlers(self._state, generic_rpc_handlers)
diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json
index 70d965d..f750b05 100644
--- a/src/python/grpcio_tests/tests/tests.json
+++ b/src/python/grpcio_tests/tests/tests.json
@@ -31,6 +31,7 @@
   "unit._invocation_defects_test.InvocationDefectsTest",
   "unit._metadata_code_details_test.MetadataCodeDetailsTest",
   "unit._metadata_test.MetadataTest",
+  "unit._resource_exhausted_test.ResourceExhaustedTest",
   "unit._rpc_test.RPCTest",
   "unit._sanity._sanity_test.Sanity",
   "unit._thread_cleanup_test.CleanupThreadTest",
diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
new file mode 100644
index 0000000..88c82b5
--- /dev/null
+++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
@@ -0,0 +1,275 @@
+# Copyright 2017, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Tests server responding with RESOURCE_EXHAUSTED."""
+
+import threading
+import unittest
+
+import grpc
+from grpc import _channel
+from grpc.framework.foundation import logging_pool
+
+from tests.unit import test_common
+from tests.unit.framework.common import test_constants
+
+_REQUEST = b'\x00\x00\x00'
+_RESPONSE = b'\x00\x00\x00'
+
+_UNARY_UNARY = '/test/UnaryUnary'
+_UNARY_STREAM = '/test/UnaryStream'
+_STREAM_UNARY = '/test/StreamUnary'
+_STREAM_STREAM = '/test/StreamStream'
+
+
+def _request_stream():
+    """Returns a new iterator of STREAM_LENGTH requests for a single RPC."""
+    return iter([_REQUEST] * test_constants.STREAM_LENGTH)
+
+
+class _TestTrigger(object):
+    """Blocks RPC handlers until the test releases them all at once."""
+
+    def __init__(self, total_call_count):
+        self._total_call_count = total_call_count
+        self._pending_calls = 0
+        self._triggered = False
+        self._finish_condition = threading.Condition()
+        self._start_condition = threading.Condition()
+
+    # Wait for all calls to be blocked in their handler
+    def await_calls(self):
+        with self._start_condition:
+            while self._pending_calls < self._total_call_count:
+                self._start_condition.wait()
+
+    # Block in a response handler and wait for a trigger
+    def await_trigger(self):
+        with self._start_condition:
+            self._pending_calls += 1
+            self._start_condition.notify()
+
+        with self._finish_condition:
+            # Re-check in a loop: wait() may return without trigger() having run.
+            while not self._triggered:
+                self._finish_condition.wait()
+
+    # Finish all response handlers
+    def trigger(self):
+        with self._finish_condition:
+            self._triggered = True
+            self._finish_condition.notify_all()
+
+
+def handle_unary_unary(trigger, request, servicer_context):
+    trigger.await_trigger()
+    return _RESPONSE
+
+
+def handle_unary_stream(trigger, request, servicer_context):
+    trigger.await_trigger()
+    for _ in range(test_constants.STREAM_LENGTH):
+        yield _RESPONSE
+
+
+def handle_stream_unary(trigger, request_iterator, servicer_context):
+    trigger.await_trigger()
+    # TODO(issue:#6891) We should be able to remove this loop
+    for request in request_iterator:
+        pass
+    return _RESPONSE
+
+
+def handle_stream_stream(trigger, request_iterator, servicer_context):
+    trigger.await_trigger()
+    # TODO(issue:#6891) We should be able to remove this loop,
+    # and replace with return; yield
+    for request in request_iterator:
+        yield _RESPONSE
+
+
+class _MethodHandler(grpc.RpcMethodHandler):
+
+    def __init__(self, trigger, request_streaming, response_streaming):
+        self.request_streaming = request_streaming
+        self.response_streaming = response_streaming
+        self.request_deserializer = None
+        self.response_serializer = None
+        self.unary_unary = None
+        self.unary_stream = None
+        self.stream_unary = None
+        self.stream_stream = None
+        if self.request_streaming and self.response_streaming:
+            self.stream_stream = (
+                lambda x, y: handle_stream_stream(trigger, x, y))
+        elif self.request_streaming:
+            self.stream_unary = lambda x, y: handle_stream_unary(trigger, x, y)
+        elif self.response_streaming:
+            self.unary_stream = lambda x, y: handle_unary_stream(trigger, x, y)
+        else:
+            self.unary_unary = lambda x, y: handle_unary_unary(trigger, x, y)
+
+
+class _GenericHandler(grpc.GenericRpcHandler):
+
+    def __init__(self, trigger):
+        self._trigger = trigger
+
+    def service(self, handler_call_details):
+        if handler_call_details.method == _UNARY_UNARY:
+            return _MethodHandler(self._trigger, False, False)
+        elif handler_call_details.method == _UNARY_STREAM:
+            return _MethodHandler(self._trigger, False, True)
+        elif handler_call_details.method == _STREAM_UNARY:
+            return _MethodHandler(self._trigger, True, False)
+        elif handler_call_details.method == _STREAM_STREAM:
+            return _MethodHandler(self._trigger, True, True)
+        else:
+            return None
+
+
+class ResourceExhaustedTest(unittest.TestCase):
+
+    def setUp(self):
+        self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+        self._trigger = _TestTrigger(test_constants.THREAD_CONCURRENCY)
+        self._server = grpc.server(
+            self._server_pool,
+            handlers=(_GenericHandler(self._trigger),),
+            maximum_concurrent_rpcs=test_constants.THREAD_CONCURRENCY)
+        port = self._server.add_insecure_port('[::]:0')
+        self._server.start()
+        self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+    def tearDown(self):
+        # Release handlers still blocked, e.g. after a failed assertion.
+        self._trigger.trigger()
+        self._server.stop(0)
+
+    def testUnaryUnary(self):
+        multi_callable = self._channel.unary_unary(_UNARY_UNARY)
+        futures = []
+        for _ in range(test_constants.THREAD_CONCURRENCY):
+            futures.append(multi_callable.future(_REQUEST))
+
+        self._trigger.await_calls()
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            multi_callable(_REQUEST)
+
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         exception_context.exception.code())
+
+        future_exception = multi_callable.future(_REQUEST)
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         future_exception.exception().code())
+
+        self._trigger.trigger()
+        for future in futures:
+            self.assertEqual(_RESPONSE, future.result())
+
+        # Ensure a new request can be handled
+        self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
+
+    def testUnaryStream(self):
+        multi_callable = self._channel.unary_stream(_UNARY_STREAM)
+        calls = []
+        for _ in range(test_constants.THREAD_CONCURRENCY):
+            calls.append(multi_callable(_REQUEST))
+
+        self._trigger.await_calls()
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            next(multi_callable(_REQUEST))
+
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         exception_context.exception.code())
+
+        self._trigger.trigger()
+
+        expected = (_RESPONSE,) * test_constants.STREAM_LENGTH
+        for call in calls:
+            self.assertEqual(expected, tuple(call))
+
+        # Ensure a new request can be handled
+        new_call = multi_callable(_REQUEST)
+        self.assertEqual(expected, tuple(new_call))
+
+    def testStreamUnary(self):
+        multi_callable = self._channel.stream_unary(_STREAM_UNARY)
+        futures = []
+        for _ in range(test_constants.THREAD_CONCURRENCY):
+            futures.append(multi_callable.future(_request_stream()))
+
+        self._trigger.await_calls()
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            multi_callable(_request_stream())
+
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         exception_context.exception.code())
+
+        future_exception = multi_callable.future(_request_stream())
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         future_exception.exception().code())
+
+        self._trigger.trigger()
+
+        for future in futures:
+            self.assertEqual(_RESPONSE, future.result())
+
+        # Ensure a new request can be handled
+        self.assertEqual(_RESPONSE, multi_callable(_request_stream()))
+
+    def testStreamStream(self):
+        multi_callable = self._channel.stream_stream(_STREAM_STREAM)
+        calls = []
+        for _ in range(test_constants.THREAD_CONCURRENCY):
+            calls.append(multi_callable(_request_stream()))
+
+        self._trigger.await_calls()
+
+        with self.assertRaises(grpc.RpcError) as exception_context:
+            next(multi_callable(_request_stream()))
+
+        self.assertEqual(grpc.StatusCode.RESOURCE_EXHAUSTED,
+                         exception_context.exception.code())
+
+        self._trigger.trigger()
+
+        expected = (_RESPONSE,) * test_constants.STREAM_LENGTH
+        for call in calls:
+            self.assertEqual(expected, tuple(call))
+
+        # Ensure a new request can be handled
+        new_call = multi_callable(_request_stream())
+        self.assertEqual(expected, tuple(new_call))
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)