# Copyright 2013 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.

import base64
import contextlib
import datetime
import json
import logging
import time

import six
if six.PY2:
  import webtest  # only for endpoints

from google.appengine.datastore import datastore_stub_util
from google.appengine.ext import ndb
from google.appengine.ext import testbed

if six.PY2:  # endpoints is py2-only
  from components import endpoints_webapp2
from components import utils
from depot_tools import auto_stub


def mock_now(test, now, seconds):
  """Mocks utcnow() and ndb properties, then returns the mocked time.

  In particular, this handles DateTimeProperty and DateProperty fields that
  use auto_now or auto_now_add.
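
  Usage (a minimal sketch; assumes test is an auto_stub.TestCase and the
  chosen datetime is arbitrary):
    now = mock_now(test, datetime.datetime(2020, 1, 1), 0)
    # utils.utcnow() and auto_now/auto_now_add properties now report now.
    mock_now(test, now, 60)  # advance the mocked time by 60 seconds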
  """
  now = now + datetime.timedelta(seconds=seconds)
  test.mock(utils, 'utcnow', lambda: now)
  test.mock(ndb.DateTimeProperty, '_now', lambda _: now)
  test.mock(ndb.DateProperty, '_now', lambda _: now.date())
  return now


class Ticker(object):
  def __init__(self, start, increment=None):
    """Creates a ticker that advances a fake clock every time it is called.

    Each invocation of __call__ returns the current value, then advances it
    by the increment.

    Args:
    - start: the starting point of the timer.
    - increment: the time delta added on each call; defaults to
      datetime.timedelta(seconds=1).
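
    Usage (a minimal sketch; assumes test is an auto_stub.TestCase, so each
    call to utils.utcnow() then moves time forward by one second):
      ticker = Ticker(datetime.datetime(2020, 1, 1))
      test.mock(utils, 'utcnow', ticker)
      utils.utcnow()  # returns 2020-01-01 00:00:00
      utils.utcnow()  # returns 2020-01-01 00:00:01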
    """
    self._current = start
    self._start = start
    if increment:
      self._increment = increment
    else:
      self._increment = datetime.timedelta(seconds=1)

  def first(self):
    return self._start

  def last(self):
    return self._current

  def __call__(self):
    old = self._current
    self._current += self._increment
    return old


class TestCase(auto_stub.TestCase):
  """Support class to enable more unit testing in GAE.

  Adds support for:
  - google.appengine.api.mail.send_mail_to_admins().
  - Running task queues.
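
  Usage (a minimal sketch; MyAppTest and create_application() are hypothetical
  names from the app under test):
    class MyAppTest(test_case.TestCase):
      # Directory containing the app's index.yaml and queue.yaml.
      APP_DIR = os.path.dirname(os.path.abspath(__file__))

      def setUp(self):
        super(MyAppTest, self).setUp()
        # Needed only if the test executes enqueued tasks.
        self.app = webtest.TestApp(create_application())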
  """
  # Set APP_DIR to the root directory containing index.yaml and queue.yaml.
  # It will be used to assert the indexes and task queues are properly
  # defined. It can be left to None if no index or task queue is used by the
  # test case.
  APP_DIR = None

  # A test can explicitly acknowledge it depends on composite indexes that may
  # not be defined in index.yaml by setting this to True. It is valid only for
  # components unit tests that run outside the context of some app (APP_DIR is
  # None in this case). If APP_DIR is provided, the GAE testbed silently
  # overwrites index.yaml, and that's not what we want.
  SKIP_INDEX_YAML_CHECK = False

  # If task queues are enqueued during the unit test, self.app must be set to
  # a webtest.TestApp instance. It will be used to do the HTTP POST when
  # executing the enqueued tasks via the taskqueue module.
  app = None

  def setUp(self):
    """Initializes the commonly used stubs.

    Using init_all_stubs() costs ~10ms more to run all the tests, so only
    enable the ones known to be required. Test cases requiring more stubs can
    enable them in their setUp() function.
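
    For example, a subclass that also needs URL fetching could enable the
    corresponding stub itself (a sketch; MyTest is a hypothetical subclass):
      def setUp(self):
        super(MyTest, self).setUp()
        self.testbed.init_urlfetch_stub()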
    """
    super(TestCase, self).setUp()
    self.testbed = testbed.Testbed()
    self.testbed.activate()

    # If you have a NeedIndexError, here is the switch you need to flip to
    # have the new required indexes added automatically. Change
    # train_index_yaml to True to have index.yaml automatically updated, then
    # run your test case. Do not forget to set it back to False.
    train_index_yaml = False

    if self.SKIP_INDEX_YAML_CHECK:
      # See the comment for SKIP_INDEX_YAML_CHECK above.
      self.assertIsNone(self.APP_DIR)

    self.testbed.init_app_identity_stub()
    self.testbed.init_datastore_v3_stub(
        require_indexes=not train_index_yaml and not self.SKIP_INDEX_YAML_CHECK,
        root_path=self.APP_DIR,
        consistency_policy=datastore_stub_util.PseudoRandomHRConsistencyPolicy(
            probability=1))
    if six.PY2:
      self.testbed.init_logservice_stub()  # Not in the py3 SDK
    self.testbed.init_memcache_stub()
    self.testbed.init_modules_stub()

    # Use mocked time in memcache.
    memcache = self.testbed.get_stub(testbed.MEMCACHE_SERVICE_NAME)
    memcache._gettime = lambda: int(utils.time_time())

    # Email support.
    self.testbed.init_mail_stub()
    self.mail_stub = self.testbed.get_stub(testbed.MAIL_SERVICE_NAME)
    self.old_send_to_admins = self.mock(
        self.mail_stub, '_Dynamic_SendToAdmins', self._SendToAdmins)

    self.testbed.init_taskqueue_stub()
    self._taskqueue_stub = self.testbed.get_stub(testbed.TASKQUEUE_SERVICE_NAME)
    self._taskqueue_stub._root_path = self.APP_DIR

    self.testbed.init_user_stub()

  def tearDown(self):
    try:
      if not self.has_failed():
        remaining = self.execute_tasks()
        self.assertEqual(0, remaining,
            'Passing tests must leave behind no pending tasks, found %d.'
            % remaining)
      self.testbed.deactivate()
    finally:
      super(TestCase, self).tearDown()

  def mock_now(self, now, seconds=0):
    return mock_now(self, now, seconds)

  def mock_milliseconds_since_epoch(self, milliseconds):
    self.mock(utils, "milliseconds_since_epoch", lambda: milliseconds)

  def _SendToAdmins(self, request, *args, **kwargs):
    """Makes sure the request is logged.

    See google_appengine/google/appengine/api/mail_stub.py around line 299,
    MailServiceStub._SendToAdmins().
    """
    self.mail_stub._CacheMessage(request)
    return self.old_send_to_admins(request, *args, **kwargs)

  def execute_tasks(self, **kwargs):
    """Executes enqueued tasks that are ready to run and returns how many ran.

    A task may trigger another task.

    Sadly, the taskqueue_stub implementation does not provide a nice way to
    run them, so the pending tasks are run manually here.
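
    Usage (a minimal sketch; assumes self.app was set in setUp() and
    enqueue_cleanup_task() is a hypothetical function that enqueues one task):
      enqueue_cleanup_task()
      self.assertEqual(1, self.execute_tasks())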
    """
    self.assertEqual([None], list(self._taskqueue_stub._queues.keys()))
    ran_total = 0
    while True:
      # Do multiple loops until no task was run.
      ran = 0
      for queue in self._taskqueue_stub.GetQueues():
        if queue['mode'] == 'pull':
          continue
        for task in self._taskqueue_stub.GetTasks(queue['name']):
          # Remove 2 seconds for jitter.
          eta = task['eta_usec'] / 1e6 - 2
          if eta >= time.time():
            continue
          self.assertEqual('POST', task['method'])
          logging.info('Task: %s', task['url'])

          self._post_task(task, **kwargs)
          self._taskqueue_stub.DeleteTask(queue['name'], task['name'])
          ran += 1
      if not ran:
        return ran_total
      ran_total += ran

  def execute_task(self, url, queue_name, payload):
    """Executes the specified enqueued task.

    Raises an AssertionError if the task isn't in the queue.
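
    Usage (a minimal sketch; the URL, queue name and payload are hypothetical):
      self.execute_task('/internal/taskqueue/cleanup', 'cleanup', payload='')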
    """
    task = self._find_task(url, queue_name, payload)
    expected = {'url': url, 'queue_name': queue_name, 'payload': payload}
    if not task:
      raise AssertionError("Task is not enqueued. expected: %r" % expected)
    self._post_task(task)

  def _post_task(self, task, **kwargs):
    # Not 100% sure why the Content-Length hack is needed, nor why the
    # stub returns unicode values that break webtest's assertions.
    body = base64.b64decode(task['body'])
    headers = {k: str(v) for k, v in task['headers']}
    headers['Content-Length'] = str(len(body))
    try:
      self.app.post(task['url'], body, headers=headers, **kwargs)
    except:
      logging.error(task)
      raise

  def _find_task(self, url, queue_name, payload):
    for t in self._taskqueue_stub.GetTasks(queue_name):
      if t['url'] != url:
        continue
      if t['queue_name'] != queue_name:
        continue
      if base64.b64decode(t['body']) != payload:
        continue
      return t
    return None


if six.PY2:

  class Endpoints(object):
    """Handles endpoints API calls."""
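
    # Typically used indirectly through EndpointsTestCase below. Direct use
    # looks roughly like this (MyEndpointsService is a hypothetical
    # remote.Service subclass):
    #   endpoints_client = Endpoints(MyEndpointsService)
    #   response = endpoints_client.call_api('my_method', body={'foo': 'bar'})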

    def __init__(self, api_service_cls, regex=None, source_ip='127.0.0.1'):
      super(Endpoints, self).__init__()
      self._api_service_cls = api_service_cls
      kwargs = {'debug': True}
      if regex:
        kwargs['regex'] = regex
      self._api_app = webtest.TestApp(
          endpoints_webapp2.api_server([self._api_service_cls], **kwargs),
          extra_environ={'REMOTE_ADDR': source_ip})

    def call_api(self, method, body=None, status=(200, 204)):
      """Calls the endpoints API method identified by its name."""
      # Because body is a dict and not a ResourceContainer, there's no way to
      # tell which parameters belong in the URL and which belong in the body
      # when the HTTP method supports both. However there's no harm in
      # supplying parameters in both the URL and the body since
      # ResourceContainers don't allow the same parameter name to be used in
      # both places. Supplying parameters in both places produces no ambiguity
      # and extraneous parameters are safely ignored.
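      #
      # For example (a hypothetical method): for a POST method whose path is
      # 'tasks/{task_id}' and which also accepts a 'limit' parameter,
      #   call_api('tasks_post', body={'task_id': '123', 'limit': 10})
      # posts the full body as JSON to .../tasks/123?limit=10.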
      assert hasattr(self._api_service_cls, method), method
      info = getattr(self._api_service_cls, method).method_info
      path = info.get_path(self._api_service_cls.api_info)

      # Identify which arguments are path parameters and which are query
      # strings.
      body = body or {}
      query_strings = []
      for key, value in sorted(body.items()):
        if '{%s}' % key in path:
          path = path.replace('{%s}' % key, value)
        else:
          # We cannot tell if the parameter is a repeated field from a dict.
          # Allow all query strings to be multi-valued.
          if not isinstance(value, list):
            value = [value]
          for val in value:
            query_strings.append('%s=%s' % (key, val))
      if query_strings:
        path = '%s?%s' % (path, '&'.join(query_strings))

      path = '/_ah/api/%s/%s/%s' % (self._api_service_cls.api_info.name,
                                    self._api_service_cls.api_info.version,
                                    path)
      try:
        if info.http_method in ('GET', 'DELETE'):
          return self._api_app.get(path, status=status)
        return self._api_app.post_json(path, body, status=status)
      except Exception as e:
        # Useful for diagnosing issues in test cases.
        logging.info('%s failed: %s', path, e)
        raise

  class EndpointsTestCase(TestCase):
    """Base class for a test case that tests a Cloud Endpoints service.

    Usage:
      class MyTestCase(test_case.EndpointsTestCase):
        api_service_cls = MyEndpointsService

        def test_stuff(self):
          response = self.call_api('my_method')
          self.assertEqual(...)

        def test_expected_fail(self):
          with self.call_should_fail(403):
            self.call_api('protected_method')
    """
    # Should be set in subclasses to a subclass of remote.Service.
    api_service_cls = None
    # Should be set in subclasses to a regular expression to match against path
    # parameters. See components.endpoints_webapp2.adapter.api_server.
    api_service_regex = None

    # See call_should_fail.
    expected_fail_status = None

    _endpoints = None

    def setUp(self):
      super(EndpointsTestCase, self).setUp()
      self._endpoints = Endpoints(
          self.api_service_cls, regex=self.api_service_regex)

    def call_api(self, method, body=None, status=(200, 204)):
      if self.expected_fail_status:
        status = self.expected_fail_status
      return self._endpoints.call_api(method, body, status)

    @contextlib.contextmanager
    def call_should_fail(self, status):
      """Asserts that the endpoints call inside the guarded region of code
      fails."""
      # TODO(vadimsh): Get rid of this function and just use
      # call_api(..., status=...). It existed as a workaround for a bug that
      # has been fixed:
      # https://code.google.com/p/googleappengine/issues/detail?id=10544
      assert self.expected_fail_status is None, 'nested call_should_fail'
      assert status is not None
      self.expected_fail_status = int(status)
      try:
        yield
      except AssertionError:
        # An assertion can happen if tests are running on GAE < 1.9.31, where
        # the endpoints bug still exists (and causes the webapp guts to raise
        # an assertion). It should be rare (since we are switching to
        # GAE >= 1.9.31), so don't bother to check that the assertion was
        # indeed raised. Just swallow it if it was.
        pass
      finally:
        self.expected_fail_status = None