blob: 86fb4785cdd33ab1f5ca99906ff51f564ea6df19 [file] [log] [blame] [edit]
# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""deferred_resource converts blocking apiclient resource to deferred."""
import collections
import datetime
import functools
import httplib
import ssl
import threading
import traceback

from twisted.internet import defer, reactor, threads
from twisted.python import log as twistedLog
from twisted.python.threadpool import ThreadPool
import apiclient
import apiclient.discovery
import apiclient.errors
import apiclient.http
import oauth2client
import oauth2client.client

from infra_libs import InstrumentedHttp
# Default number of attempts DeferredResource makes for one RPC.
DEFAULT_RETRY_ATTEMPT_COUNT = 5
# Default initial wait (seconds) between attempts; doubled after each retry.
DEFAULT_RETRY_WAIT_SECONDS = 1

# Also treat 403 Forbidden as one of oauth2client's REFRESH_STATUS_CODES,
# i.e. a response status that makes authorized requests refresh the access
# token. oauth2client.client is a submodule, so it is imported explicitly at
# the top of the file instead of relying on another import having loaded it.
if httplib.FORBIDDEN not in oauth2client.client.REFRESH_STATUS_CODES:
  oauth2client.client.REFRESH_STATUS_CODES.append(httplib.FORBIDDEN)
class NotStartedError(Exception):
  """Raised when an RPC is attempted on a DeferredResource that isn't started."""
class CredentialFactory(object):
  """Creates credentials.

  Wraps a callable that produces credentials, optionally paired with a
  time-to-live after which the produced credentials should be recreated.
  """

  # Time-to-live (datetime.timedelta) of created credentials.
  # None means the credentials can live forever.
  ttl = None

  def __init__(self, f, ttl=None):
    self.ttl = ttl
    self.f = f

  def __call__(self, *call_args, **call_kwargs):
    """Invokes the wrapped callable and returns whatever it creates."""
    factory = self.f
    return factory(*call_args, **call_kwargs)
class DaemonThreadPool(ThreadPool):
  """A twisted ThreadPool whose worker threads are daemon threads.

  Daemon threads do not keep the Python process alive when it exits.
  """

  def threadFactory(self, *args, **kwargs):
    """Creates a worker thread, flagged as daemon before it is returned."""
    worker = threading.Thread(*args, **kwargs)
    worker.daemon = True
    return worker
class DeferredResource(object):
  """Wraps an apiclient Resource, converts its methods to deferred.

  Accepts an apiclient.Resource, such as one generated by
  apiclient.discovery.build, and wraps all resource methods. When a deferred
  resource method is called, it schedules an actual RPC in a twisted thread
  pool and returns a Deferred.

  Has to be explicitly started and stopped. This can be done using "with"
  statement, see examples.

  Examples:
    Basic usage:

      @defer.inlineCallbacks
      def greet():
        # Asynchronously build a DeferredResource for my_greeting_service API.
        service = yield DeferredResource.build(
            'my_greeting_service', 'v1', http_client_name='greeter')
        with service:
          response = yield service.api.greet('John')
        defer.returnValue(response)

    Authorization:

      with open(secret_key_filename, 'rb') as f:
        secret_key = f.read()
      AUTH_SCOPE = 'https://www.googleapis.com/auth/userinfo.email'
      creds = SignedJwtAssertionCredentials(service_account, secret_key,
                                            AUTH_SCOPE)
      service = yield DeferredResource.build(
          'my_greeting_service', 'v1', credentials=creds,
          http_client_name='greeter')

  Also DeferredResource retries requests on transient errors with exponential
  backoff.
  """

  class Api(object):
    """Wraps an apiclient resource or method.

    Can dynamically create Api objects for nested methods: any attribute
    access returns a child Api for the extended path (so a misspelled name
    only fails when called), and calling an Api invokes the resource method
    at its path through the owning DeferredResource.
    """

    def __init__(self, owner, path=None):
      self._cached_method = None
      self._api_cache = {}
      self._owner = owner
      self._path = tuple(path) if path else ()

    def __call__(self, *args, **kwargs):
      if not self._cached_method:
        self._cached_method = self._owner._twistify(self._path)
        if not self._cached_method:
          raise AttributeError(
              'Resource does not have method %s' % '.'.join(self._path))
      return self._cached_method(*args, **kwargs)

    def __getattr__(self, name):
      new_path = self._path + (name,)
      sub_api = self._api_cache.get(new_path)
      if sub_api is None:
        sub_api = self._owner.Api(self._owner, path=new_path)
        self._api_cache[new_path] = sub_api
      return sub_api

  def __init__(
      self, resource, credentials=None, max_concurrent_requests=1,
      retry_wait_seconds=None, retry_attempt_count=None, verbose=False,
      log_prefix='', timeout=None, _pool=None, http_client_name=None):
    """Creates a DeferredResource.

    Args:
      resource (apiclient.Resource): a resource, such as one generated by
        apiclient.discovery.build.
      credentials (oauth2client.client.Credentials or CredentialFactory):
        credentials to use to make API requests.
      max_concurrent_requests (int): maximum number of concurrent requests.
        Defaults to 1.
      retry_wait_seconds (int, float): initial wait interval for request
        retrial. In seconds, defaults to 1. Doubled after each retry, capped
        at 30 seconds.
      retry_attempt_count (int): number of attempts before giving up.
        Defaults to 5.
      verbose (bool): if True, log each request/response.
      log_prefix (str): prefix for log messages.
      timeout (int): request timeout in seconds. If None is passed
        then Python's default timeout for sockets will be used. See
        for example the docs of socket.setdefaulttimeout():
        http://docs.python.org/library/socket.html#socket.setdefaulttimeout
      _pool (ThreadPool): internal; thread pool to run requests in. If None,
        a DaemonThreadPool with max_concurrent_requests threads is created.
      http_client_name (str): an identifier for the HTTP requests made by this
        resource. Included with monitoring metrics. Required.
    """
    max_concurrent_requests = max_concurrent_requests or 1
    assert resource, 'resource not specified'
    if retry_wait_seconds is None:
      retry_wait_seconds = DEFAULT_RETRY_WAIT_SECONDS
    assert isinstance(retry_wait_seconds, (int, float))
    if retry_attempt_count is None:
      retry_attempt_count = DEFAULT_RETRY_ATTEMPT_COUNT
    assert isinstance(retry_attempt_count, int)
    assert http_client_name

    self._pool = _pool or self._create_thread_pool(
        max_concurrent_requests, http_client_name)
    self._resource = resource
    self.credentials = credentials
    self.retry_wait_seconds = retry_wait_seconds
    self.retry_attempt_count = retry_attempt_count
    self.verbose = verbose
    self.log_prefix = log_prefix
    self.api = self.Api(self)
    # Per-thread Http and credentials; see _get_thread_http.
    self._th_local = threading.local()
    self.started = False
    self.timeout = timeout
    self.http_client_name = http_client_name

  @classmethod
  def _create_thread_pool(cls, max_concurrent_requests, name):
    """Returns a new DaemonThreadPool with 1..max_concurrent_requests threads."""
    return DaemonThreadPool(
        minthreads=1, maxthreads=max_concurrent_requests, name=name)

  @classmethod
  def _create_async(
      cls, resource_factory, max_concurrent_requests=1, _pool=None,
      http_client_name=None, **kwargs):
    """Creates a DeferredResource without blocking the reactor thread.

    Args:
      resource_factory (func() -> apiclient.Resource): returns the resource
        to wrap. It is called in a pool thread, so it may block.
      max_concurrent_requests (int): size of the thread pool created when
        _pool is None.
      _pool (ThreadPool): thread pool to use; created if None. It is stopped
        once the DeferredResource is created.
      http_client_name (str): see __init__.
      **kwargs: passed through to the DeferredResource constructor.

    Returns:
      A Deferred firing with the new DeferredResource, or failing with the
      exception raised while creating it.
    """
    _pool = _pool or cls._create_thread_pool(
        max_concurrent_requests, http_client_name)
    result = defer.Deferred()

    def create_sync():
      try:
        assert resource_factory, 'resource_factory is not specified'
        res = resource_factory()
        def_res = cls(res, _pool=_pool, http_client_name=http_client_name,
                      **kwargs)
      except Exception as ex:
        deliver, value = result.errback, ex
      else:
        deliver, value = result.callback, def_res
      finally:
        # Stop the thread pool after creating DeferredResource. This is
        # scheduled only after the (possibly slow) factory has returned:
        # Twisted's ThreadPool.stop() joins the worker threads, so scheduling
        # it earlier could block the reactor thread for the whole factory
        # call. It is still scheduled before the result is delivered, so a
        # callback that starts the resource is not undone by this stop.
        reactor.callFromThread(_pool.stop)
      reactor.callFromThread(deliver, value)

    _pool.start()
    _pool.callInThread(create_sync)
    return result

  # Yes, I've copied all these parameters because being explicit is good.
  @classmethod
  def build(
      cls, service_name, version, credentials=None, max_concurrent_requests=1,
      discoveryServiceUrl=apiclient.discovery.DISCOVERY_URI,
      developerKey=None, model=None,
      requestBuilder=apiclient.http.HttpRequest,
      retry_wait_seconds=None, retry_attempt_count=None, verbose=False,
      log_prefix='', timeout=None, http_client_name=None):
    """Asynchronously builds a DeferredResource for a discoverable API.

    Asynchronously builds a resource by calling apiclient.discovery.build and
    wraps it with a DeferredResource.

    Args:
      service_name: string, name of the service.
      version: string, the version of the service.
      credentials (oauth2client.client.Credentials or CredentialFactory):
        credentials to use to make API requests.
      max_concurrent_requests (int): maximum number of concurrent requests.
        Defaults to 1.
      discoveryServiceUrl: string, a URI Template that points to the location
        of the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URI to the
        discovery document for that service.
      developerKey: string, key obtained from
        https://code.google.com/apis/console.
      model: apiclient.Model, converts to and from the wire format.
      requestBuilder: apiclient.http.HttpRequest, encapsulator for an HTTP
        request.
      retry_wait_seconds (int, float): initial wait interval for request
        retrial. In seconds, defaults to 1.
      retry_attempt_count (int): number of attempts before giving up.
        Defaults to 5.
      verbose (bool): if True, log each request/response.
      log_prefix (str): prefix for log messages.
      timeout (int): request timeout in seconds. If None is passed
        then Python's default timeout for sockets will be used. See
        for example the docs of socket.setdefaulttimeout():
        http://docs.python.org/library/socket.html#socket.setdefaulttimeout
      http_client_name (str): an identifier for the HTTP requests made by this
        resource. Included with monitoring metrics. Required.

    Returns:
      A DeferredResource as Deferred.
    """
    # Do not check arguments synchronously. Let the client check for exceptions
    # only in errback.
    def resource_factory():
      return apiclient.discovery.build(
          service_name,
          version,
          discoveryServiceUrl=discoveryServiceUrl,
          developerKey=developerKey,
          model=model,
          requestBuilder=requestBuilder,
      )

    return cls._create_async(
        resource_factory,
        credentials=credentials,
        max_concurrent_requests=max_concurrent_requests,
        retry_wait_seconds=retry_wait_seconds,
        retry_attempt_count=retry_attempt_count,
        verbose=verbose,
        log_prefix=log_prefix,
        timeout=timeout,
        http_client_name=http_client_name,
    )

  def log(self, message):
    """Writes |message| to the twisted log, prefixed with log_prefix."""
    twistedLog.msg('%s%s' % (self.log_prefix, message))

  def start(self):
    """Starts the thread pool; RPCs may be made after this."""
    self._pool.start()
    self.started = True

  def stop(self):
    """Stops the thread pool.

    New RPCs then fail with NotStartedError, and failed RPCs are not retried.
    """
    self.started = False
    self._pool.stop()

  def __enter__(self):
    self.start()
    return self

  def __exit__(self, *args, **kwargs):
    self.stop()

  @defer.inlineCallbacks
  def _retry(self, method_name, call):
    """Runs |call| in the thread pool, retrying transient errors.

    Transient errors (see is_transient) are retried with exponential backoff,
    starting at retry_wait_seconds and capped at 30 seconds.

    Args:
      method_name (str): name of the remote method, for logging.
      call (func() -> any): a function that makes an RPC call and returns
        result.

    Returns:
      A Deferred firing with the result of |call|. It fails with
      NotStartedError if the resource is not started, or with the last
      exception raised by |call| when it is not transient, attempts are
      exhausted, or the resource is stopped meanwhile.
    """
    # Always make at least one attempt; with a count <= 0 the loop would
    # never run and the Deferred would silently fire with None.
    attempts = max(1, self.retry_attempt_count)
    wait = self.retry_wait_seconds
    while attempts > 0:
      attempts -= 1
      try:
        if not self.started:
          raise NotStartedError('DeferredResource is not started')
        res = yield threads.deferToThreadPool(reactor, self._pool, call)
      except Exception as ex:
        if not self.started:
          raise
        if attempts > 0 and is_transient(ex):
          # %s rather than %d: retry_wait_seconds may be a float.
          self.log('Transient error while calling %s. '
                   'Will retry in %s seconds.' % (method_name, wait))
          # TODO(nodir), optimize: stop waiting if the resource is stopped.
          yield sleep(wait)
          if not self.started:
            # Re-raise explicitly: after the yield, |ex| may no longer be
            # the exception currently being handled.
            raise ex
          wait = min(wait * 2, 30)
          continue
        self.log('RPC "%s" failed: %s' % (method_name, traceback.format_exc()))
        raise
      else:
        defer.returnValue(res)

  def _log_request(self, method_name, args, kwargs):
    """Logs a call of |method_name| with the given arguments."""
    arg_str_list = [repr(arg) for arg in args]
    arg_str_list.extend('%s=%r' % (k, v) for k, v in kwargs.items())
    self.log('Request %s(%s)' % (method_name, ', '.join(arg_str_list)))

  def _get_thread_http(self):
    """Returns the calling thread's Http object, ready to make a request.

    Each thread gets its own Http object and, if self.credentials is set, its
    own copy of the credentials authorizing it. They are created on the first
    call in a thread, and recreated once credentials made by a
    CredentialFactory with a ttl expire. Credentials whose token_expiry is
    within 5 minutes are refreshed ahead of time.
    """
    now = datetime.datetime.utcnow()
    local = self._th_local
    create_creds = (
        getattr(local, 'http', None) is None or
        (local.credentials_expiry is not None and
         local.credentials_expiry <= now))
    if create_creds:
      # Build everything in locals and commit to thread-local state only
      # once it all succeeded. Otherwise a failure while creating credentials
      # would leave an unauthorized Http cached for this thread forever.
      http = InstrumentedHttp(self.http_client_name, timeout=self.timeout)
      credentials = None
      credentials_expiry = None
      if self.credentials is not None:
        creds = self.credentials
        ttl = None
        if isinstance(creds, CredentialFactory):
          ttl = creds.ttl
          creds = creds()
        # Round-trip through JSON to give this thread its own credentials
        # object (presumably because they are not safe to share — confirm).
        credentials = creds.from_json(creds.to_json())
        if ttl is not None:
          credentials_expiry = now + ttl
        http = credentials.authorize(http)
      local.credentials = credentials
      local.credentials_expiry = credentials_expiry
      local.http = http
    elif getattr(local.credentials, 'token_expiry', None):
      # Check token_expiry more aggressively:
      # refresh if it expires in <= 5 min.
      expiry = local.credentials.token_expiry - datetime.timedelta(minutes=5)
      if expiry <= now:
        local.credentials.refresh(local.http)
    return local.http

  def _twistify(self, path):
    """Wraps the resource method at |path| so that it returns a Deferred.

    Args:
      path (tuple of str): attribute path from the resource to the method.
        A callable intermediate value is called first and the next name is
        looked up on its result (apiclient exposes nested resources as
        methods).

    Returns:
      A function taking the resource method's arguments and returning a
      Deferred with the response, or None if |path| does not resolve.
    """
    method = self._resource
    for component in path:
      if callable(method):
        method = method()
      method = getattr(method, component, None)
      if method is None:
        return None
    method_name = '.'.join(path)

    @functools.wraps(method)
    def twistified(*args, **kwargs):
      def single_call():
        http = self._get_thread_http()
        if self.verbose:
          self._log_request(method_name, args, kwargs)
        response = method(*args, **kwargs).execute(http=http)
        if self.verbose:
          self.log('Response: %s' % response)
        return response
      return self._retry(method_name, single_call)
    return twistified
def sleep(secs):
  """Returns a Deferred that fires with None after |secs| seconds.

  The delay is scheduled on the reactor with callLater, so unlike time.sleep
  it does not block the calling thread.
  """
  deferred = defer.Deferred()
  reactor.callLater(secs, deferred.callback, None)
  return deferred
def is_transient(ex):
  """Returns True if |ex| is an error worth retrying.

  Transient errors are:
    * apiclient HttpErrors that carry a response with status >= 500;
    * ssl.SSLErrors caused by a timeout.
  """
  # httplib2.Response is a dict subclass, so test for None explicitly rather
  # than relying on truthiness (which depends on the header count).
  if isinstance(ex, apiclient.errors.HttpError) and ex.resp is not None:
    return ex.resp.status >= 500
  if isinstance(ex, ssl.SSLError):
    # Such timeouts carry no reason and no errcode, so the message is the
    # only thing to match on.
    return 'timed out' in str(ex)
  return False