blob: 0e09d8dfd8644f69c856d6d2d0fd2622e24250f3 [file]
#!/usr/bin/env python
# Copyright 2016 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.
import contextlib
import json
import logging
import os
import socket
import sys
import time
import unittest
# Repository root: the parent of the directory containing this file. __file__
# is decoded with the filesystem encoding so the path is handled as unicode
# (Python 2 idiom; presumably for non-ASCII checkout paths).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(
__file__.decode(sys.getfilesystemencoding()))))
# Make the repository packages and its bundled 'third_party' directory
# importable. Both are inserted at the front, so 'third_party' (inserted last)
# ends up ahead of ROOT_DIR on sys.path. This must run before the project
# imports that follow.
sys.path.insert(0, ROOT_DIR)
sys.path.insert(0, os.path.join(ROOT_DIR, 'third_party'))
from depot_tools import auto_stub
from depot_tools import fix_encoding
from third_party import requests
from utils import authenticators
from utils import auth_server
from utils import net
from utils import oauth
from libs import luci_context
import net_utils
def global_test_setup():
  """Applies process-wide tweaks that are only appropriate under test.

  Drops the poll interval of auth_server's HTTP server to 10 ms so servers
  started by the tests terminate much faster. Polling that often impacts
  performance, which is why it is done only in tests.
  """
  http_server_cls = auth_server._HTTPServer
  http_server_cls.poll_interval = 0.01
def call_rpc(account_id, scopes):
  """Invokes GetOAuthToken on the local auth server named in LUCI_CONTEXT.

  The port and secret are read from the 'local_auth' section of LUCI_CONTEXT.

  Args:
    account_id: account to request a token for.
    scopes: list of OAuth scopes to request.

  Returns:
    The JSON-decoded response body.
  """
  local_auth = luci_context.read('local_auth')
  endpoint = (
      'http://127.0.0.1:%d/rpc/LuciLocalAuthService.GetOAuthToken' %
      local_auth['rpc_port'])
  payload = {
      'account_id': account_id,
      'scopes': scopes,
      'secret': local_auth['secret'],
  }
  reply = requests.post(
      url=endpoint,
      data=json.dumps(payload),
      headers={'Content-Type': 'application/json'})
  return reply.json()
@contextlib.contextmanager
def local_auth_server(token_cb, default_account_id, **overrides):
  """Runs a LocalAuthServer and publishes it through LUCI_CONTEXT.

  The server knows the accounts 'acc_1', 'acc_2' and 'acc_3'.

  Args:
    token_cb: callable(account_id, scopes) used to generate every token. The
        tests pass callbacks that return an auth_server.AccessToken or raise
        auth_server.TokenError.
    default_account_id: default account ID handed to the server.
    **overrides: keys that replace values in the 'local_auth' section written
        to LUCI_CONTEXT (the tests use e.g. rpc_port and secret).

  The server is stopped on exit, including when start() or the body raises.
  """
  class CallbackProvider(object):
    """Token provider that forwards every request to token_cb."""

    def generate_token(self, account_id, scopes):
      return token_cb(account_id, scopes)

  server = auth_server.LocalAuthServer()
  try:
    section = server.start(
        token_provider=CallbackProvider(),
        accounts=('acc_1', 'acc_2', 'acc_3'),
        default_account_id=default_account_id)
    section.update(overrides)
    with luci_context.write(local_auth=section):
      yield
  finally:
    server.stop()
class LocalAuthServerTest(auto_stub.TestCase):
  """Tests LocalAuthServer's GetOAuthToken RPC endpoint directly over HTTP."""

  # Fake "now" in seconds; mock_time() offsets relative to it.
  epoch = 12345678

  def setUp(self):
    super(LocalAuthServerTest, self).setUp()
    self.mock_time(0)

  def mock_time(self, delta):
    """Makes time.time() return self.epoch + delta."""
    self.mock(time, 'time', lambda: self.epoch + delta)

  def _forbidden_token_gen(self, calls):
    """Returns a token callback that the test expects never to be invoked.

    The callback is run by the auth server, not on the test's own thread, so
    self.fail() there would not by itself fail the test. It therefore also
    records the call in `calls`, which the test checks from its own thread.
    """
    def token_gen(account_id, scopes):
      calls.append((account_id, scopes))
      self.fail('must not be called')
    return token_gen

  def test_works(self):
    calls = []
    def token_gen(account_id, scopes):
      calls.append((account_id, scopes))
      return auth_server.AccessToken('tok_%s' % account_id, time.time() + 300)

    with local_auth_server(token_gen, 'acc_1'):
      # Grab initial token. The provider sees the scopes deduplicated and
      # sorted.
      resp = call_rpc('acc_1', ['B', 'B', 'A', 'C'])
      self.assertEqual(
          {u'access_token': u'tok_acc_1', u'expiry': self.epoch + 300}, resp)
      self.assertEqual([('acc_1', ('A', 'B', 'C'))], calls)
      del calls[:]

      # Reuses cached token until it is close to expiration.
      self.mock_time(60)
      resp = call_rpc('acc_1', ['B', 'A', 'C'])
      self.assertEqual(
          {u'access_token': u'tok_acc_1', u'expiry': self.epoch + 300}, resp)
      self.assertFalse(calls)

      # Asking for different account gives another token.
      resp = call_rpc('acc_2', ['B', 'B', 'A', 'C'])
      self.assertEqual(
          {u'access_token': u'tok_acc_2', u'expiry': self.epoch + 360}, resp)
      self.assertEqual([('acc_2', ('A', 'B', 'C'))], calls)
      del calls[:]

      # First token has expired. Generated new one.
      self.mock_time(300)
      resp = call_rpc('acc_1', ['A', 'B', 'C'])
      self.assertEqual(
          {u'access_token': u'tok_acc_1', u'expiry': self.epoch + 600}, resp)
      self.assertEqual([('acc_1', ('A', 'B', 'C'))], calls)

  def test_handles_token_errors(self):
    calls = []
    def token_gen(_account_id, _scopes):
      calls.append(1)
      raise auth_server.TokenError(123, 'error message')

    with local_auth_server(token_gen, 'acc_1'):
      self.assertEqual(
          {u'error_code': 123, u'error_message': u'error message'},
          call_rpc('acc_1', ['B', 'B', 'A', 'C']))
      self.assertEqual(1, len(calls))

      # Errors are cached. Same error is returned.
      self.assertEqual(
          {u'error_code': 123, u'error_message': u'error message'},
          call_rpc('acc_1', ['B', 'B', 'A', 'C']))
      self.assertEqual(1, len(calls))

  def test_http_level_errors(self):
    calls = []
    token_gen = self._forbidden_token_gen(calls)

    with local_auth_server(token_gen, 'acc_1'):
      ctx = luci_context.read('local_auth')
      rpc_url = (
          'http://127.0.0.1:%d/rpc/LuciLocalAuthService.GetOAuthToken' %
          ctx['rpc_port'])
      valid_body = json.dumps({
          'account_id': 'acc_1',
          'scopes': ['A', 'B', 'C'],
          'secret': ctx['secret'],
      })
      json_headers = {'Content-Type': 'application/json'}

      # Wrong URL.
      r = requests.post(
          url='http://127.0.0.1:%d/blah/LuciLocalAuthService.GetOAuthToken' %
          ctx['rpc_port'],
          data=valid_body,
          headers=json_headers)
      self.assertEqual(404, r.status_code)

      # Wrong HTTP method.
      r = requests.get(url=rpc_url, data=valid_body, headers=json_headers)
      self.assertEqual(501, r.status_code)

      # Wrong content type.
      r = requests.post(
          url=rpc_url,
          data=valid_body,
          headers={'Content-Type': 'application/xml'})
      self.assertEqual(400, r.status_code)

      # Bad JSON.
      r = requests.post(url=rpc_url, data='not a json', headers=json_headers)
      self.assertEqual(400, r.status_code)

    # None of the malformed requests may reach the token provider.
    self.assertEqual([], calls)

  def test_validation(self):
    calls = []
    token_gen = self._forbidden_token_gen(calls)

    with local_auth_server(token_gen, 'acc_1'):
      ctx = luci_context.read('local_auth')
      rpc_url = (
          'http://127.0.0.1:%d/rpc/LuciLocalAuthService.GetOAuthToken' %
          ctx['rpc_port'])

      def must_fail(body, err, code):
        r = requests.post(
            url=rpc_url,
            data=json.dumps(body),
            headers={'Content-Type': 'application/json'})
        # All cases share one loop; name the request body on failure.
        msg = 'request body: %r' % (body,)
        self.assertEqual(code, r.status_code, msg)
        self.assertIn(err, r.text, msg)

      cases = [
        # account_id
        ({}, '"account_id" is required', 400),
        ({'account_id': 123}, '"account_id" must be a string', 400),
        # scopes
        ({'account_id': 'acc_1'}, '"scopes" is required', 400),
        ({'account_id': 'acc_1', 'scopes': []}, '"scopes" is required', 400),
        (
          {'account_id': 'acc_1', 'scopes': 'abc'},
          '"scopes" must be a list of strings',
          400,
        ),
        (
          {'account_id': 'acc_1', 'scopes': [1]},
          '"scopes" must be a list of strings',
          400,
        ),
        # secret
        ({'account_id': 'acc_1', 'scopes': ['a']}, '"secret" is required', 400),
        (
          {'account_id': 'acc_1', 'scopes': ['a'], 'secret': 123},
          '"secret" must be a string',
          400,
        ),
        (
          {'account_id': 'acc_1', 'scopes': ['a'], 'secret': 'abc'},
          'Invalid "secret"',
          403,
        ),
        # The account is unknown.
        (
          {'account_id': 'zzz', 'scopes': ['a'], 'secret': ctx['secret']},
          'Unrecognized account ID',
          404,
        ),
      ]
      for body, err, code in cases:
        must_fail(body, err, code)

    # Every request is invalid, so the token provider must never be asked.
    self.assertEqual([], calls)
class LocalAuthHttpServiceTest(auto_stub.TestCase):
  """Tests for LocalAuthServer and LuciContextAuthenticator."""

  # Fake "now" in seconds; mock_time() offsets relative to it.
  epoch = 12345678

  def setUp(self):
    super(LocalAuthHttpServiceTest, self).setUp()
    self.mock_time(0)

  def mock_time(self, delta):
    """Makes time.time() return self.epoch + delta."""
    self.mock(time, 'time', lambda: self.epoch + delta)

  @staticmethod
  def mocked_http_service(
      url='http://example.com',
      perform_request=None):
    """Returns a net.HttpService for `url` using LuciContextAuthenticator.

    Outgoing requests are handed to perform_request(request) when it is given;
    otherwise the engine returns None for every request.
    """
    class MockedRequestEngine(object):
      def perform_request(self, request):
        return perform_request(request) if perform_request else None

      @classmethod
      def timeout_exception_classes(cls):
        return ()

      @classmethod
      def parse_request_exception(cls, exc):
        del exc  # Unused argument
        return None, None

    return net.HttpService(
        url,
        authenticator=authenticators.LuciContextAuthenticator(),
        engine=MockedRequestEngine())

  def _forbidden_token_gen(self, calls):
    """Returns a token callback that the test expects never to be invoked.

    The callback is run by the auth server, not on the test's own thread, so
    self.fail() there would not by itself fail the test. It therefore also
    records the call in `calls`, which the test checks from its own thread.
    """
    def token_gen(account_id, scopes):
      calls.append((account_id, scopes))
      self.fail('must not be called')
    return token_gen

  def _check_token_requests(self, calls):
    """Asserts each recorded (account_id, scopes) asked acc_1 for OAUTH_SCOPES.

    Token callbacks only record their arguments, because assertions on the
    auth server's side would not by themselves fail the test.
    """
    for account_id, scopes in calls:
      self.assertEqual('acc_1', account_id)
      self.assertEqual(1, len(scopes))
      self.assertEqual(oauth.OAUTH_SCOPES, scopes[0])

  def test_works(self):
    service_url = 'http://example.com'
    request_url = '/some_request'
    response = 'True'
    token = 'notasecret'
    token_calls = []

    def token_gen(account_id, scopes):
      token_calls.append((account_id, scopes))
      return auth_server.AccessToken(token, time.time() + 300)

    def handle_request(request):
      self.assertTrue(
          request.get_full_url().startswith(service_url + request_url))
      self.assertEqual('', request.body)
      self.assertEqual(u'Bearer %s' % token,
                       request.headers['Authorization'])
      return net_utils.make_fake_response(response, request.get_full_url())

    with local_auth_server(token_gen, 'acc_1'):
      service = self.mocked_http_service(
          url=service_url, perform_request=handle_request)
      self.assertEqual(service.request(request_url, data={}).read(), response)

    # A token was attached, so the provider must have been asked at least once.
    self.assertTrue(token_calls)
    self._check_token_requests(token_calls)

  def test_bad_secret(self):
    service_url = 'http://example.com'
    request_url = '/some_request'
    response = 'False'
    token_calls = []
    token_gen = self._forbidden_token_gen(token_calls)

    def handle_request(request):
      self.assertTrue(
          request.get_full_url().startswith(service_url + request_url))
      self.assertEqual('', request.body)
      # LUCI_CONTEXT carries a wrong secret, so no token may be attached.
      self.assertIsNone(request.headers.get('Authorization'))
      return net_utils.make_fake_response(response, request.get_full_url())

    with local_auth_server(token_gen, 'acc_1', secret='invalid'):
      service = self.mocked_http_service(
          url=service_url, perform_request=handle_request)
      self.assertEqual(service.request(request_url, data={}).read(), response)

    self.assertEqual([], token_calls)

  def test_bad_port(self):
    request_url = '/some_request'
    token_calls = []
    http_calls = []
    token_gen = self._forbidden_token_gen(token_calls)

    def handle_request(request):
      http_calls.append(request)
      self.fail('must not be called')

    # This little dance should pick an unused port, bind it and then close it,
    # trusting that the OS will not reallocate it between now and when the http
    # client attempts to use it as a local_auth service. This is better than
    # picking a static port number, as there's at least some guarantee that the
    # port WASN'T in use before this test ran.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      sock.bind(('localhost', 0))
      port = sock.getsockname()[1]
    finally:
      # Release the probe socket even if bind() fails.
      sock.close()

    with local_auth_server(token_gen, 'acc_1', rpc_port=port):
      service = self.mocked_http_service(perform_request=handle_request)
      # Nothing listens on `port`, so fetching a token must fail with
      # socket.error before any HTTP request is dispatched.
      with self.assertRaises(socket.error):
        service.request(request_url, data={}).read()

    self.assertEqual([], token_calls)
    self.assertEqual([], http_calls)

  def test_expired_token(self):
    service_url = 'http://example.com'
    request_url = '/some_request'
    response = 'False'
    token = 'notasecret'
    token_calls = []

    def token_gen(account_id, scopes):
      token_calls.append((account_id, scopes))
      # Expires at the current (mocked) time, i.e. is already unusable.
      return auth_server.AccessToken(token, time.time())

    def handle_request(request):
      self.assertTrue(
          request.get_full_url().startswith(service_url + request_url))
      self.assertEqual('', request.body)
      self.assertIsNone(request.headers.get('Authorization'))
      return net_utils.make_fake_response(response, request.get_full_url())

    with local_auth_server(token_gen, 'acc_1'):
      service = self.mocked_http_service(
          url=service_url, perform_request=handle_request)
      self.assertEqual(service.request(request_url, data={}).read(), response)

    self._check_token_requests(token_calls)
if __name__ == '__main__':
  fix_encoding.fix_encoding()
  # Full logging with -v; otherwise only critical messages.
  if '-v' in sys.argv:
    log_level = logging.DEBUG
  else:
    log_level = logging.CRITICAL
  logging.basicConfig(level=log_level)
  global_test_setup()
  unittest.main()