blob: bf5a663000d1cab2cb3a0c249bd9e2b1d516ef85 [file] [log] [blame]
# Copyright 2018 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Implements oauth2client's OAuth2Credentials subclass that uses LUCI local
auth (usually on Swarming) for getting access tokens.
"""
import collections
import datetime
import json
import logging
import httplib2
import oauth2client.client
from infra_libs import luci_ctx
# OAuth scopes requested when a caller does not pass `scopes` explicitly; this
# is the default `scopes` argument of get_access_token() and LUCICredentials.
DEFAULT_SCOPES = ('https://www.googleapis.com/auth/userinfo.email',)
class LUCIAuthError(Exception):
  """Raised when an access token cannot be obtained from LUCI local auth.

  Covers the cases where LUCI auth is not available, its 'local_auth' section
  is malformed, or the local auth server replies with an error.
  """
def available(environ=None):
  """Returns True if the environment has ambient LUCI auth setup.

  Args:
    environ: an environ dict to use instead of os.environ, for tests.

  Returns:
    True if LUCI_CONTEXT has a 'local_auth' section with a non-empty
    'default_account_id', False otherwise.

  Raises:
    LUCIAuthError if the 'local_auth' section is malformed.
  """
  params = _local_auth_params(environ)
  # Coerce to a real bool: a bare `params and params.default_account_id` would
  # hand back None or the account ID itself rather than the documented
  # True/False.
  return bool(params and params.default_account_id)
def get_access_token(scopes=DEFAULT_SCOPES, environ=None, http=None):
  """Grabs an access token from LUCI local server.

  Does RPC each time it is called. For efficiency, either cache the token until
  it expires, or use LUCICredentials class that does it itself already.

  Args:
    scopes: list of OAuth scopes the constructed token should have.
    environ: an environ dict to use instead of os.environ, for tests.
    http: Http instance to use for request, default is new httplib2.Http.

  Returns:
    Tuple (access token str, datetime.datetime when it expires or None if not).
    The expiry, when present, is a naive datetime in UTC.

  Raises:
    LUCIAuthError on unexpected errors: LUCI auth is not available, the local
      auth server replied with a non-200 status, or its reply is unparsable,
      reports an error, lacks an access token or carries a bad expiry.
    TypeError if `scopes` is not a list or tuple of strings.
  """
  _check_scopes(scopes)

  params = _local_auth_params(environ)
  if not params or not params.default_account_id:
    raise LUCIAuthError('LUCI auth is not available')

  logging.debug(
      'luci_auth: requesting an access token for account "%s"',
      params.default_account_id)

  http = http or httplib2.Http()
  host = '127.0.0.1:%d' % params.rpc_port
  resp, content = http.request(
      uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host,
      method='POST',
      body=json.dumps(
          {
              'account_id': params.default_account_id,
              'scopes': scopes,
              'secret': params.secret,
          },
          # Sort the keys to simplify mocking.
          sort_keys=True),
      headers={'Content-Type': 'application/json'})
  if resp.status != 200:
    raise LUCIAuthError(
        'Failed to grab access token from local auth server %s, status %d: %r' %
        (host, resp.status, content))

  try:
    token = json.loads(content)
  except ValueError as e:
    raise LUCIAuthError('Cannot parse the local auth server reply: %s' % e)
  # Valid JSON that is not an object (e.g. `null` or a list) has no .get();
  # report it as an auth error instead of leaking AttributeError.
  if not isinstance(token, dict):
    raise LUCIAuthError(
        'Unexpected local auth server reply, not a JSON object: %r' % (token,))

  error_code = token.get('error_code')
  if error_code:
    # %s rather than %d: renders integer codes identically, but a non-integer
    # code no longer blows up with TypeError and hides the real error.
    raise LUCIAuthError(
        'Error #%s when retrieving access token: %s' %
        (error_code, token.get('error_message')))

  access_token = token.get('access_token')
  if not access_token:
    # Without this check a missing token would be returned as the literal
    # string 'None' by the str() conversion below.
    raise LUCIAuthError('The local auth server reply has no access token')
  # Same str type regardless of whether the token expires.
  access_token = str(access_token)

  expiry = token.get('expiry')
  if expiry is None:
    logging.debug(
        'luci_auth: got an access token for account "%s" that does not expire',
        params.default_account_id)
    return access_token, None

  try:
    # Naive UTC on purpose: it is compared against utcnow() right below.
    exp = datetime.datetime.utcfromtimestamp(expiry)
  except (TypeError, ValueError, OverflowError, OSError) as e:
    raise LUCIAuthError(
        'Bad expiry %r in the local auth server reply: %s' % (expiry, e))
  logging.debug(
      'luci_auth: got an access token for account "%s" that expires in %d sec',
      params.default_account_id,
      (exp - datetime.datetime.utcnow()).total_seconds())
  return access_token, exp
class LUCICredentials(oauth2client.client.OAuth2Credentials):
  """OAuth2Credentials flavor whose access tokens come from LUCI local auth."""

  def __init__(self, scopes=DEFAULT_SCOPES, environ=None, http=None):
    """Builds credentials that request tokens with the given scopes.

    Args:
      scopes: list of OAuth scopes the constructed token should have.
      environ: an environ dict to use instead of os.environ, for tests.
      http: Http instance to use for request, default is new httplib2.Http.

    Raises:
      LUCIAuthError if LUCI auth is not available in the environment.
      TypeError if `scopes` is not a list or tuple of strings.
    """
    if not available(environ=environ):
      raise LUCIAuthError('LUCI auth is not available')

    # There is nothing static to hand to the base class: all seven of its
    # positional constructor arguments stay None, and the token itself is
    # filled in later by _refresh().
    unset = (None,) * 7
    super(LUCICredentials, self).__init__(*unset)

    _check_scopes(scopes)
    self._environ = environ
    self._http = http
    self._scopes = tuple(scopes)

  def _refresh(self, _http_request):
    """Replaces the held token with a fresh one from get_access_token().

    Args:
      _http_request: unused; the request is made through the Http instance
        given to the constructor (or a new one when that was None).
    """
    # pylint: disable=attribute-defined-outside-init
    # The attributes assigned below are 'inherited' from OAuth2Credentials.
    fresh_token, fresh_expiry = get_access_token(
        self._scopes, environ=self._environ, http=self._http)
    self.access_token = fresh_token
    self.token_expiry = fresh_expiry
    self.invalid = False
###
# Represents LUCI_CONTEXT['local_auth'] dict.
_LocalAuthParams = collections.namedtuple(
'_LocalAuthParams', [
'default_account_id',
'secret',
'rpc_port',
])
def _local_auth_params(environ):
  """Reads _LocalAuthParams from LUCI_CONTEXT, returns None if not there.

  Args:
    environ: an environ dict to use instead of os.environ, for tests; passed
      through to luci_ctx.read().

  Returns:
    _LocalAuthParams built from the 'local_auth' section, or None if that
    section is absent or empty.

  Raises:
    LUCIAuthError if the section is malformed, e.g. 'rpc_port' is missing or
      is not an integer.
  """
  p = luci_ctx.read('local_auth', environ=environ)
  if not p:
    return None
  try:
    return _LocalAuthParams(
        default_account_id=p.get('default_account_id'),
        secret=p.get('secret'),
        rpc_port=int(p.get('rpc_port')))
  except (AttributeError, TypeError, ValueError) as e:
    # ValueError: 'rpc_port' is a non-numeric string.
    # TypeError: 'rpc_port' is missing (int(None)) or of a non-numeric type.
    # AttributeError: the section is not a dict-like object with .get().
    raise LUCIAuthError('local_auth section is malformed: %s' % e)
def _check_scopes(scopes):
"""Raises TypeError is list of scopes is malformed."""
if not isinstance(scopes, (list, tuple)):
raise TypeError('Scopes must be a list or a tuple')
if not all(isinstance(s, str) for s in scopes):
raise TypeError('Each scope must be a string')