blob: 526b279fbb0b7e7fcc0655e37e86c3d35474aac8 [file] [log] [blame]
#!/usr/bin/env vpython3
# Copyright 2023 The LUCI Authors. All rights reserved.
# Use of this source code is governed under the Apache License, Version 2.0
# that can be found in the LICENSE file.
"""Unit Tests for pycurl.py"""
import io
import logging
import os
import sys
import tempfile
import unittest
from unittest import mock

import requests

import pycurl
class PyCurlTest(unittest.TestCase):
  """Tests pycurl._download against a mocked requests.Session."""

  def setUp(self):
    # Replace requests.Session with a MagicMock for the duration of each test.
    # A MagicMock returns the same child mock on every call, so the response
    # configured by _mock_response() is the object returned by any later
    # requests.Session().get(...) call (presumably how pycurl fetches it).
    mock.patch('requests.Session').start()
    self.addCleanup(mock.patch.stopall)

  def _mock_response(self, body, headers):
    """Configures and returns the mocked HTTP response.

    Args:
      body: bytes served as the response body.
      headers: dict of response headers.

    Returns:
      The mock response, with status_code set to requests.codes.ok and
      iter_content() returning an iterator over `body`.
    """
    response = requests.Session().get()
    response.status_code = requests.codes.ok
    response.headers = headers
    response.raw = io.BytesIO(body)
    # Iterating a BytesIO yields its content line by line, so _download
    # receives the body as one or more chunks.
    response.iter_content.return_value = response.raw
    return response

  def testSuccess(self):
    self._mock_response(b'ok', {'Content-Length': '2'})
    code, total = pycurl._download('https://test/', os.devnull, None, 0, '')
    self.assertEqual(code, requests.codes.ok)
    self.assertEqual(total, 2)

  def testShortRead(self):
    # Content-Length promises 6 bytes but only 5 are delivered.
    self._mock_response(b'short', {'Content-Length': '6'})
    with self.assertRaises(ValueError) as context:
      pycurl._download('https://test/', os.devnull, None, 0, '')
    self.assertIn('Expected content length:', str(context.exception))

  def testInvalidContentLength(self):
    self._mock_response(b'anything', {'Content-Length': 'abc'})
    with self.assertRaises(ValueError) as context:
      pycurl._download('https://test/', os.devnull, None, 0, '')
    # This is the message int() produces for a non-numeric string.
    self.assertIn('invalid literal for int()', str(context.exception))

  def testWithoutContentLength(self):
    self._mock_response(b'anything', {})
    code, total = pycurl._download('https://test/', os.devnull, None, 0, '')
    self.assertEqual(code, requests.codes.ok)
    self.assertEqual(total, 8)

  def testStripPrefix(self):
    self._mock_response(b")]}'\nok", {})
    # Close the temp file before _download reopens it by name (reopening an
    # open NamedTemporaryFile by name is not supported on Windows), and make
    # sure it is deleted when the test finishes.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
      path = tmp.name
    self.addCleanup(os.remove, path)
    pycurl._download('https://test/', path, None, 0, ")]}'\n")
    with open(path, 'rb') as f:
      # assertEqual, not assertTrue: assertTrue's second argument is only the
      # failure message, so it would accept any non-empty output.
      self.assertEqual(f.read(), b'ok')
if __name__ == '__main__':
  # '-v' stays in sys.argv, so unittest.main() also runs verbosely. It
  # additionally turns on DEBUG-level logging. This relies on the module-level
  # `import logging`; without it this branch raised NameError.
  if '-v' in sys.argv:
    logging.basicConfig(level=logging.DEBUG)
  unittest.main()