blob: 805e2a77bc318b9225a15c4c1af3a6e1ba0446da [file] [log] [blame]
# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import argparse
import datetime as dt
import json
import math
import pathlib
import unittest
from typing import Any
from urllib import parse as urlparse
from typing_extensions import override
from crossbench import path as pth
from crossbench.parse import (DurationParseError, DurationParser, NumberParser,
ObjectParser, PathParser, TimeUnit)
from protoc import trace_config_pb2
from tests import test_helper
from tests.crossbench.base import CrossbenchFakeFsTestCase
class DurationParserTestCase(unittest.TestCase):
def test_parse_negative(self):
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_duration(-1)
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration("-1")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_or_zero_duration("-1")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_or_zero_duration(dt.timedelta(seconds=-1))
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration("-1")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration(dt.timedelta(seconds=-1))
self.assertIn("-1", str(cm.exception))
self.assertEqual(DurationParser.any_duration("-1.5").total_seconds(), -1.5)
def test_parse_zero(self):
self.assertEqual(DurationParser.any_duration("0").total_seconds(), 0)
self.assertEqual(DurationParser.any_duration("0s").total_seconds(), 0)
self.assertEqual(DurationParser.any_duration("0.0").total_seconds(), 0)
self.assertEqual(
DurationParser.positive_or_zero_duration("0.0").total_seconds(), 0)
invalid: Any
for invalid in (-1, 0, "-1", "0", "invalid", dt.timedelta(0),
dt.timedelta(seconds=-1)):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration(invalid)
self.assertIn(str(invalid), str(cm.exception))
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration(invalid)
self.assertIn(str(invalid), str(cm.exception))
def test_parse_empty(self):
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_duration("")
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.any_duration("")
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_or_zero_duration("")
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_duration("")
def test_invalid_suffix(self):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
DurationParser.positive_duration("100XXX")
self.assertIn("Unknown duration format", str(cm.exception))
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_duration("X0XX")
with self.assertRaises(argparse.ArgumentTypeError):
DurationParser.positive_duration("100X0XX")
def test_no_unit(self):
self.assertEqual(
DurationParser.positive_duration("200"), dt.timedelta(seconds=200))
self.assertEqual(
DurationParser.positive_duration(200), dt.timedelta(seconds=200))
self.assertEqual(
DurationParser.positive_duration_ms("200"),
dt.timedelta(milliseconds=200))
self.assertEqual(
DurationParser.positive_duration_ms(200),
dt.timedelta(milliseconds=200))
def test_milliseconds(self):
self.assertEqual(
DurationParser.positive_duration("27.5ms"),
dt.timedelta(milliseconds=27.5))
self.assertEqual(
DurationParser.positive_duration(dt.timedelta(milliseconds=27.5)),
dt.timedelta(milliseconds=27.5))
self.assertEqual(
DurationParser.positive_duration("27.5 millis"),
dt.timedelta(milliseconds=27.5))
self.assertEqual(
DurationParser.positive_duration("27.5 milliseconds"),
dt.timedelta(milliseconds=27.5))
def test_seconds(self):
self.assertEqual(
DurationParser.positive_duration("27.5s"), dt.timedelta(seconds=27.5))
self.assertEqual(
DurationParser.positive_duration("1 sec"), dt.timedelta(seconds=1))
self.assertEqual(
DurationParser.positive_duration("27.5 secs"),
dt.timedelta(seconds=27.5))
self.assertEqual(
DurationParser.positive_duration("1 second"), dt.timedelta(seconds=1))
self.assertEqual(
DurationParser.positive_duration("27.5 seconds"),
dt.timedelta(seconds=27.5))
def test_minutes(self):
self.assertEqual(
DurationParser.positive_duration("27.5m"), dt.timedelta(minutes=27.5))
self.assertEqual(
DurationParser.positive_duration("1 min"), dt.timedelta(minutes=1))
self.assertEqual(
DurationParser.positive_duration("27.5 mins"),
dt.timedelta(minutes=27.5))
self.assertEqual(
DurationParser.positive_duration("1 minute"), dt.timedelta(minutes=1))
self.assertEqual(
DurationParser.positive_duration("27.5 minutes"),
dt.timedelta(minutes=27.5))
def test_hours(self):
self.assertEqual(
DurationParser.positive_duration("27.5h"), dt.timedelta(hours=27.5))
self.assertEqual(
DurationParser.positive_duration("0.1 h"), dt.timedelta(hours=0.1))
self.assertEqual(
DurationParser.positive_duration("27.5 hrs"), dt.timedelta(hours=27.5))
self.assertEqual(
DurationParser.positive_duration("1 hour"), dt.timedelta(hours=1))
self.assertEqual(
DurationParser.positive_duration("27.5 hours"),
dt.timedelta(hours=27.5))
class PathParserTestCase(CrossbenchFakeFsTestCase):
def test_existing_path(self):
path = pth.LocalPath("foo.txt")
self.assertFalse(path.exists())
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.existing_path(str(path))
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.existing_path(path)
path.touch()
self.assertEqual(path, PathParser.existing_path(str(path)))
self.assertEqual(path, PathParser.existing_path(path))
def test_not_existing_path(self):
path = pth.LocalPath("foo.txt")
self.assertFalse(path.exists())
self.assertEqual(path, PathParser.not_existing_path(str(path)))
self.assertEqual(path, PathParser.not_existing_path(path))
path.touch()
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.not_existing_path(str(path))
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.not_existing_path(path)
class ObjectParserTestCase(CrossbenchFakeFsTestCase):
@override
def setUp(self):
super().setUp()
self._json_test_data = {"int": 1, "array": [1, "2"]}
def test_parse_any_str(self):
self.assertEqual(ObjectParser.any_str(""), "")
self.assertEqual(ObjectParser.any_str("1234"), "1234")
def test_parse_any_str_invalid(self):
invalid: Any
for invalid in (None, 1, [], {}, [1], ["a"], {"a": "a"}):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
ObjectParser.any_str(invalid)
self.assertIn(str(invalid), str(cm.exception))
def test_parse_non_empty_str(self):
self.assertEqual(ObjectParser.non_empty_str("a string"), "a string")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
ObjectParser.non_empty_str("")
self.assertIn("empty", str(cm.exception))
def test_parse_str_or_file_contents(self):
with self.assertRaisesRegex(argparse.ArgumentTypeError, "empty"):
ObjectParser.str_or_file_contents("")
self.assertEqual(
ObjectParser.str_or_file_contents("some data"), "some data")
self.assertEqual(ObjectParser.str_or_file_contents("test.txt"), "test.txt")
def test_parse_str_or_file_contents_file(self):
path = pathlib.Path("./test.txt")
with self.assertRaisesRegex(argparse.ArgumentTypeError, str(path)):
ObjectParser.str_or_file_contents(path)
with self.assertRaisesRegex(argparse.ArgumentTypeError, str(path)):
ObjectParser.str_or_file_contents("./test.txt")
self.fs.create_file(path, contents="test file contents")
self.assertEqual(
ObjectParser.str_or_file_contents(path), "test file contents")
self.assertEqual(ObjectParser.str_or_file_contents(str(path)), "test.txt")
self.assertEqual(
ObjectParser.str_or_file_contents("./test.txt"), "test file contents")
def test_parse_httpx_url_str(self):
for valid in ("http://foo.com", "https://foo.com", "http://localhost:800"):
self.assertEqual(ObjectParser.httpx_url_str(valid), valid)
invalid: Any
for invalid in ("", "ftp://localhost:32", "http://///"):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
_ = ObjectParser.httpx_url_str(invalid)
self.assertIn(invalid, str(cm.exception))
def test_parse_any_int(self):
self.assertEqual(NumberParser.any_int("-123456"), -123456)
self.assertEqual(NumberParser.any_int(-123456), -123456)
self.assertEqual(NumberParser.any_int(float(-123456)), -123456)
self.assertEqual(NumberParser.any_int("-1"), -1)
self.assertEqual(NumberParser.any_int(-1), -1)
self.assertEqual(NumberParser.any_int(float(-1)), -1)
self.assertEqual(NumberParser.any_int("0"), 0)
self.assertEqual(NumberParser.any_int(0), 0)
self.assertEqual(NumberParser.any_int(float(0)), 0)
self.assertEqual(NumberParser.any_int("1"), 1)
self.assertEqual(NumberParser.any_int(1), 1)
self.assertEqual(NumberParser.any_int("123456"), 123456)
self.assertEqual(NumberParser.any_int(123456), 123456)
def test_parse_any_int_strict(self):
self.assertEqual(NumberParser.any_int(float(0), parse_str=False), 0)
self.assertEqual(NumberParser.any_int(1, parse_str=False), 1)
def test_parse_any_int_invalid(self):
invalid: Any
for invalid in ("", "-1.2", -1.2, "1.2", 1.2, "100.001", 100.001, "Nan",
math.nan, "inf", math.inf, "-inf", -math.inf, "invalid",
None):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.any_int(invalid)
def test_parse_any_int_invalid_strict(self):
invalid: Any
for invalid in ("", "-1.2", -1.2, "1.2", 1.2, "100.001", 100.001, "Nan",
math.nan, "inf", math.inf, "-inf", -math.inf, "invalid",
None):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.any_int(invalid, parse_str=False)
def test_parse_positive_int(self):
self.assertEqual(NumberParser.positive_int("1"), 1)
self.assertEqual(NumberParser.positive_int(1), 1)
self.assertEqual(NumberParser.positive_int("123"), 123)
self.assertEqual(NumberParser.positive_int(123), 123)
def test_parse_positive_int_invalid(self):
invalid: Any
for invalid in ("", "0", 0, "-1", -1, "-1.2", -1.2, "1.2", 1.2, "Nan",
math.nan, "inf", math.inf, "-inf", -math.inf, "invalid",
None):
with self.assertRaises(
argparse.ArgumentTypeError, msg=f"invalid={repr(invalid)}"):
_ = NumberParser.positive_int(invalid)
def test_parse_int_range(self):
self.assertEqual(NumberParser.int_range(min=0, max=10)("1"), 1)
self.assertEqual(NumberParser.int_range(min=0, max=10)(1), 1)
self.assertEqual(NumberParser.int_range(min=0, max=200)("123"), 123)
self.assertEqual(NumberParser.int_range(min=0, max=200)(123), 123)
self.assertEqual(NumberParser.int_range(min=-100, max=200)("-12"), -12)
self.assertEqual(NumberParser.int_range(min=-100, max=200)(-12), -12)
def test_parse_int_range_invalid(self):
with self.assertRaises(AssertionError):
NumberParser.int_range(1, 1)
with self.assertRaises(AssertionError):
NumberParser.int_range(10, 1)
with self.assertRaises(argparse.ArgumentTypeError):
NumberParser.int_range(-1, 10)(-2)
with self.assertRaises(argparse.ArgumentTypeError):
NumberParser.int_range(-1, 10)(11)
def test_parse_positive_int_invalid_strict(self):
invalid: Any
for invalid in ("", "0", 0, "1", "-1", -1, float(-1), "-1.2", -1.2, "1.2",
1.2, "Nan", math.nan, "inf", math.inf, "-inf", -math.inf,
"invalid", None):
with self.assertRaises(
argparse.ArgumentTypeError, msg=f"invalid={repr(invalid)}"):
_ = NumberParser.positive_int(invalid, parse_str=False)
def test_parse_positive_zero_int(self):
self.assertEqual(NumberParser.positive_zero_int("1"), 1)
self.assertEqual(NumberParser.positive_zero_int(1), 1)
self.assertEqual(NumberParser.positive_zero_int(float(1)), 1)
self.assertEqual(NumberParser.positive_zero_int("0"), 0)
self.assertEqual(NumberParser.positive_zero_int(0), 0)
def test_parse_positive_zero_int_invalid(self):
invalid: Any
for invalid in ("", "-1", -1, "-1.2", -1.2, "1.2", 1.2, "NaN", math.nan,
"inf", math.inf, "-inf", -math.inf, "invalid", None):
with self.assertRaises(
argparse.ArgumentTypeError, msg=f"invalid={repr(invalid)}"):
_ = NumberParser.positive_zero_int(invalid)
def test_parse_any_float(self):
self.assertEqual(NumberParser.any_float("-1.2"), -1.2)
self.assertEqual(NumberParser.any_float(-1.2), -1.2)
self.assertEqual(NumberParser.any_float("-1"), -1.0)
self.assertEqual(NumberParser.any_float(-1), -1.0)
self.assertEqual(NumberParser.any_float("0"), 0.0)
self.assertEqual(NumberParser.any_float(0), 0.0)
self.assertEqual(NumberParser.any_float("0.0"), 0.0)
self.assertEqual(NumberParser.any_float(0.0), 0.0)
self.assertEqual(NumberParser.any_float("0.1"), 0.1)
self.assertEqual(NumberParser.any_float(0.1), 0.1)
def test_parse_float_invalid(self):
invalid: Any
for invalid in ("", "abc", "NaN", "inf", "-inf", "invalid"):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.positive_zero_float(invalid)
def test_parse_positive_zero_float(self):
self.assertEqual(NumberParser.positive_zero_float("1"), 1.0)
self.assertEqual(NumberParser.positive_zero_float("0"), 0.0)
self.assertEqual(NumberParser.positive_zero_float("0.0"), 0.0)
self.assertEqual(NumberParser.positive_zero_float("1.23"), 1.23)
def test_parse_positive_zero_float_invalid(self):
invalid: Any
for invalid in ("", "-1", "-1.2", "NaN", "inf", "-inf", "invalid"):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.positive_zero_float(invalid)
def test_parse_float_range(self):
self.assertEqual(NumberParser.float_range(min=1, max=2)("1"), 1.0)
self.assertEqual(NumberParser.float_range(min=0, max=1)(1), 1.0)
self.assertEqual(NumberParser.float_range(min=0, max=1)("0"), 0.0)
self.assertEqual(NumberParser.float_range(min=0, max=1)(0), 0.0)
self.assertEqual(NumberParser.float_range(min=0, max=1)("0.0"), 0.0)
self.assertEqual(NumberParser.float_range(min=0, max=1)(0.0), 0.0)
self.assertEqual(NumberParser.float_range(min=0, max=11)("1.23"), 1.23)
self.assertEqual(NumberParser.float_range(min=0, max=11)(1.23), 1.23)
self.assertEqual(NumberParser.float_range(min=-2, max=11)("-1.1"), -1.1)
self.assertEqual(NumberParser.float_range(min=-2, max=11)(-1.1), -1.1)
def test_parse_float_range_invalid(self):
with self.assertRaises(AssertionError):
NumberParser.float_range(1, 1)
with self.assertRaises(AssertionError):
NumberParser.float_range(10, 1.0)
with self.assertRaises(AssertionError):
NumberParser.float_range(-10.1, -11.0)
with self.assertRaises(argparse.ArgumentTypeError):
NumberParser.int_range(-1.1, 10.1)(-1.2)
with self.assertRaises(argparse.ArgumentTypeError):
NumberParser.int_range(-1.1, 10.1)(10.2)
def test_parse_port_number(self):
self.assertEqual(NumberParser.port_number(1), 1)
self.assertEqual(NumberParser.port_number("1"), 1)
self.assertEqual(NumberParser.port_number(440), 440)
self.assertEqual(NumberParser.port_number("440"), 440)
self.assertEqual(NumberParser.port_number(65535), 65535)
self.assertEqual(NumberParser.port_number("65535"), 65535)
def test_parse_port_number_zer(self):
self.assertEqual(NumberParser.port_number_zero(0), 0)
self.assertEqual(NumberParser.port_number_zero("0"), 0)
self.assertEqual(NumberParser.port_number(1), 1)
self.assertEqual(NumberParser.port_number("1"), 1)
self.assertEqual(NumberParser.port_number_zero(440), 440)
self.assertEqual(NumberParser.port_number_zero("440"), 440)
self.assertEqual(NumberParser.port_number_zero(65535), 65535)
self.assertEqual(NumberParser.port_number_zero("65535"), 65535)
def test_parse_port_number_invalid(self):
invalid: Any
for invalid in ("", "-1", "-1.2", "6553500", 6553500, "inf", "-inf",
"invalid", 0, "0"):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.port_number(invalid)
def test_parse_port_number_invalid_zero(self):
invalid: Any
for invalid in (
"",
"-1",
"-1.2",
"6553500",
"inf",
"-inf",
"invalid",
):
with self.assertRaises(argparse.ArgumentTypeError):
_ = NumberParser.port_number_zero(invalid)
def _json_file_test_helper(self, parser) -> Any:
with self.assertRaises(argparse.ArgumentTypeError):
parser("file")
path = pathlib.Path("file.json")
self.assertFalse(path.exists())
with self.assertRaises(argparse.ArgumentTypeError):
parser(path)
path.touch()
with self.assertRaises(argparse.ArgumentTypeError):
parser(path)
with path.open("w", encoding="utf-8") as f:
f.write("{invalid json data")
with self.assertRaises(argparse.ArgumentTypeError):
parser(path)
# Test very long lines too.
with path.open("w", encoding="utf-8") as f:
f.write("{\n invalid json data" + "." * 100)
with self.assertRaises(argparse.ArgumentTypeError):
parser(path)
with path.open("w", encoding="utf-8") as f:
f.write("""{
'a': {},
'c': }}
""")
with self.assertRaises(argparse.ArgumentTypeError):
parser(path)
with path.open("w", encoding="utf-8") as f:
json.dump(self._json_test_data, f)
str_result = parser(str(path))
path_result = parser(path)
self.assertEqual(str_result, path_result)
return str_result
def test_parse_json_file(self):
result = self._json_file_test_helper(ObjectParser.json_file)
self.assertDictEqual(self._json_test_data, result)
def test_parse_json_file_path(self):
result = self._json_file_test_helper(PathParser.json_file_path)
self.assertEqual(pathlib.Path("file.json"), result)
def test_parse_hjson_file_path(self):
result = self._json_file_test_helper(PathParser.hjson_file_path)
self.assertEqual(pathlib.Path("file.json"), result)
def test_parse_inline_hjson(self):
with self.assertRaisesRegex(argparse.ArgumentTypeError, "empty"):
ObjectParser.inline_hjson("")
with self.assertRaisesRegex(argparse.ArgumentTypeError, "braces"):
ObjectParser.inline_hjson("{")
with self.assertRaisesRegex(argparse.ArgumentTypeError, "braces"):
ObjectParser.inline_hjson("[]")
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.inline_hjson("{invalid json}")
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.inline_hjson("{'asdfas':'asdf}")
self.assertDictEqual(
self._json_test_data,
ObjectParser.inline_hjson(json.dumps(self._json_test_data)))
def test_parse_dir_path(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.dir_path("")
file = pathlib.Path("file")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.dir_path(file)
self.assertIn("does not exist", str(cm.exception))
file.touch()
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.dir_path(file)
self.assertIn("not a folder", str(cm.exception))
folder = pathlib.Path("folder")
folder.mkdir()
self.assertEqual(folder, PathParser.dir_path(folder))
self.assertEqual(folder, PathParser.dir_path(str(folder)))
def test_parse_non_empty_dir_path(self):
folder = pathlib.Path("folder")
folder.mkdir()
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.non_empty_dir_path(folder)
self.assertIn("empty", str(cm.exception))
(folder / "foo").touch()
self.assertEqual(folder, PathParser.non_empty_dir_path(folder))
self.assertEqual(folder, PathParser.non_empty_dir_path(str(folder)))
def test_parse_non_empty_file_path(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.non_empty_file_path("")
folder = pathlib.Path("folder")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.non_empty_file_path(folder)
self.assertIn("does not exist", str(cm.exception))
folder.mkdir()
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.non_empty_file_path(folder)
self.assertIn("not a file", str(cm.exception))
file = pathlib.Path("file")
file.touch()
with self.assertRaises(argparse.ArgumentTypeError) as cm:
self.assertEqual(file, PathParser.non_empty_file_path(file))
self.assertIn("is an empty file", str(cm.exception))
with file.open("w", encoding="utf-8") as f:
f.write("fooo")
self.assertEqual(file, PathParser.non_empty_file_path(file))
def test_parse_existing_file_path(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.existing_file_path("")
folder = pathlib.Path("folder")
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.existing_file_path(folder)
self.assertIn("does not exist", str(cm.exception))
folder.mkdir()
with self.assertRaises(argparse.ArgumentTypeError) as cm:
PathParser.existing_file_path(folder)
self.assertIn("not a file", str(cm.exception))
file = pathlib.Path("file")
file.touch()
self.assertEqual(file, PathParser.existing_file_path(file))
def test_parse_path(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.path("")
folder = pathlib.Path("folder")
folder.mkdir()
self.assertEqual(folder, PathParser.path(folder))
file = pathlib.Path("file")
file.touch()
self.assertEqual(file, PathParser.path(file))
def test_parse_any_path_invalid(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.any_path("")
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.any_path(None)
def test_parse_any_path(self):
folder = pathlib.Path("folder")
folder_pure = pathlib.PurePath(folder)
self.assertEqual(folder_pure, PathParser.any_path(folder))
folder.mkdir()
self.assertEqual(folder_pure, PathParser.any_path(folder))
file = pathlib.Path("file")
file_pure = pathlib.PurePath(file)
self.assertEqual(file_pure, PathParser.any_path(file))
file.touch()
self.assertEqual(file_pure, PathParser.any_path(file))
def test_parse_optional_any_path_invalid(self):
with self.assertRaises(argparse.ArgumentTypeError):
PathParser.optional_any_path("")
def test_parse_optional_any_path(self):
self.assertIsNone(PathParser.optional_any_path(None))
folder = pathlib.Path("folder")
folder_pure = pathlib.PurePath(folder)
self.assertEqual(folder_pure, PathParser.optional_any_path(folder))
def test_parse_bool_success(self):
self.assertIs(ObjectParser.bool("true"), True)
self.assertIs(ObjectParser.bool("True"), True)
self.assertIs(ObjectParser.bool(True), True)
self.assertIs(ObjectParser.bool("false"), False)
self.assertIs(ObjectParser.bool("False"), False)
self.assertIs(ObjectParser.bool(False), False)
def test_parse_bool_success_strict(self):
self.assertIs(ObjectParser.bool(True, strict=True), True)
self.assertIs(ObjectParser.bool(False, strict=True), False)
def test_parse_bool_invalid(self):
invalid: Any
for invalid in (1, 0, "1", "0", "", None, [], tuple()):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.bool(invalid)
ObjectParser.bool(invalid, strict=True)
def test_parse_bool_invalid_strict(self):
invalid: Any
for invalid in (None, "False", "false", "True", "true"):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.bool(invalid, strict=True)
def test_parse_optional_bool(self):
self.assertIsNone(ObjectParser.optional_bool(None))
self.assertIs(ObjectParser.optional_bool("true"), True)
self.assertIs(ObjectParser.optional_bool("false"), False)
def test_parse_optional_bool_invalid(self):
invalid: Any
for invalid in (1, 0, "1", "0", "", [], tuple()):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.optional_bool(invalid)
ObjectParser.optional_bool(invalid, strict=True)
def test_parse_optional_bool_invalid_strict(self):
invalid: Any
for invalid in ("False", "false", "True", "true"):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.optional_bool(invalid, strict=True)
def test_parse_sh_cmd(self):
self.assertListEqual(ObjectParser.sh_cmd("ls -al ."), ["ls", "-al", "."])
self.assertListEqual(ObjectParser.sh_cmd("ls -al '.'"), ["ls", "-al", "."])
self.assertListEqual(
ObjectParser.sh_cmd(";ls -al '.'"), [";ls", "-al", "."])
self.assertListEqual(
ObjectParser.sh_cmd(("ls", "-al", ".")), ["ls", "-al", "."])
def test_parse_sh_cmd_invalid(self):
invalid: Any
for invalid in (1, "", None, [], "ls -al \"."):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.sh_cmd(invalid)
def test_parse_dict_invalid(self):
invalid: Any
for invalid in (1, 0, "1", "0", "", None, [], tuple()):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.dict(invalid)
def test_parse_dict(self):
self.assertDictEqual(ObjectParser.dict({}), {})
self.assertDictEqual(ObjectParser.dict({"A": 2}), {"A": 2})
def test_parse_non_empty_dict_invalid(self):
invalid: Any
for invalid in (1, 0, "1", "0", "", None, [], tuple(), {}):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.non_empty_dict(invalid)
def test_parse_non_empty_dict(self):
result = ObjectParser.non_empty_dict({"a": 1})
self.assertDictEqual(result, {"a": 1})
def test_parse_unique_sequence(self):
self.assertListEqual(ObjectParser.unique_sequence([]), [])
self.assertTupleEqual(ObjectParser.unique_sequence(tuple()), tuple())
self.assertListEqual(ObjectParser.unique_sequence([1, 2, 3]), [1, 2, 3])
self.assertTupleEqual(ObjectParser.unique_sequence((1, 2, 3)), (1, 2, 3))
def test_parse_unique_sequence_invalid(self):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
ObjectParser.unique_sequence([1, 1, 2, 2, 2, 3, 5, 5])
self.assertIn("duplicates", str(cm.exception))
self.assertIn("1, 2, 5", str(cm.exception))
def test_parse_unique_sequence_custom_exception(self):
class CustomException(Exception):
pass
with self.assertRaises(CustomException):
ObjectParser.unique_sequence([1, 1], error_cls=CustomException)
def test_parse_unique_sequence_custom_name(self):
with self.assertRaises(argparse.ArgumentTypeError) as cm:
ObjectParser.unique_sequence([1, 1], name="custom test name")
self.assertIn("custom test name", str(cm.exception))
def test_parse_sequence(self):
self.assertListEqual(ObjectParser.sequence([]), [])
self.assertListEqual(ObjectParser.sequence([1, 2]), [1, 2])
self.assertTupleEqual(ObjectParser.sequence(tuple()), tuple())
self.assertTupleEqual(ObjectParser.sequence((1, 2)), (1, 2))
def test_parse_sequence_invalid(self):
invalid: Any
for invalid in ("", "1", 1, {}, {"a": 1}, set(), set((1, 2))):
with self.subTest(invalid=invalid):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.sequence(invalid)
def test_parse_non_empty_sequence(self):
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.non_empty_sequence([])
self.assertListEqual(ObjectParser.non_empty_sequence([1, 2]), [1, 2])
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.non_empty_sequence(tuple())
self.assertTupleEqual(ObjectParser.non_empty_sequence((1, 2)), (1, 2))
def test_parse_non_empty_sequence_invalid(self):
invalid: Any
for invalid in ("", "1", 1, {}, {"a": 1}, set(), set((1, 2)), (), []):
with self.subTest(invalid=invalid):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.non_empty_sequence(invalid)
def test_parse_fuzzy_url_invalid(self):
invalid = ("foo", "x64/123", "x64/www.google.com")
for url in invalid:
with self.subTest(url=url):
with self.assertRaisesRegex(argparse.ArgumentTypeError, url):
ObjectParser.fuzzy_url(url)
def test_parse_fuzzy_url(self):
expected = (
("/foo/bar", "file:///foo/bar"),
("C:/foo/bar", "file://C:/foo/bar"),
("1234.com", "https://1234.com"),
("http://1234.com", "http://1234.com"),
("http://127.0.0.1", "http://127.0.0.1"),
("http://localhost", "http://localhost"),
("127.0.0.1", "https://127.0.0.1"),
("localhost", "https://localhost"),
("test.com", "https://test.com"),
("test.com/", "https://test.com/"),
("test.com/1234", "https://test.com/1234"),
("test.com/bar", "https://test.com/bar"),
("test.com/bar?x=1", "https://test.com/bar?x=1"),
("test.com:1234", "https://test.com:1234"),
("test.com:1234/", "https://test.com:1234/"),
("test.com:1234/56", "https://test.com:1234/56"),
("test.com:1234/56/", "https://test.com:1234/56/"),
("test.com:1234/bar", "https://test.com:1234/bar"),
("test.com:1234/bar?x=1", "https://test.com:1234/bar?x=1"),
("localhost:8123", "https://localhost:8123"),
("localhost:8123", "https://localhost:8123"),
("localhost:8123/", "https://localhost:8123/"),
("localhost:8123/77", "https://localhost:8123/77"),
("localhost:8123/77/", "https://localhost:8123/77/"),
("localhost:8123/bar", "https://localhost:8123/bar"),
("localhost:8123/bar?x=1", "https://localhost:8123/bar?x=1"),
("data:text/html,this is some data",
"data:text/html,this is some data"),
("chrome://extensions", "chrome://extensions"),
)
for url, result in expected:
with self.subTest(url=url):
self.assertEqual(ObjectParser.fuzzy_url_str(url), result)
parsed = ObjectParser.fuzzy_url(url)
self.assertEqual(urlparse.urlunparse(parsed), result)
def test_parse_fuzzy_url_default_scheme(self):
expected = ("test.com", "test.com/", "test.com/bar", "test.com/bar?x=1",
"test.com:1234", "test.com:1234/", "test.com:1234/bar",
"test.com:1234/bar?x=1", "localhost:8123", "localhost:8123/",
"localhost:8123/bar", "localhost:8123/bar?x1")
for url in expected:
with self.subTest(url=url):
result_default = f"https://{url}"
self.assertEqual(ObjectParser.fuzzy_url_str(url), result_default)
parsed = ObjectParser.fuzzy_url(url)
self.assertEqual(urlparse.urlunparse(parsed), result_default)
result_custom = f"ftp://{url}"
self.assertEqual(
ObjectParser.fuzzy_url_str(url, default_scheme="ftp"),
result_custom)
parsed = ObjectParser.fuzzy_url(url, default_scheme="ftp")
self.assertEqual(urlparse.urlunparse(parsed), result_custom)
def test_parse_url(self):
expected = (
("file:///foo/bar", "file:///foo/bar"),
("about:blank", "about:blank"),
("http://test.com/bar", "http://test.com/bar"),
("https://test.com/bar", "https://test.com/bar"),
("http://test.com", "http://test.com"),
("https://test.com/", "https://test.com/"),
("http://test.com/bar", "http://test.com/bar"),
("https://test.com/bar?x=1", "https://test.com/bar?x=1"),
("http://test.com:1234", "http://test.com:1234"),
("https://test.com:1234/", "https://test.com:1234/"),
("http://test.com:1234/bar", "http://test.com:1234/bar"),
("https://test.com:1234/bar?x=1", "https://test.com:1234/bar?x=1"),
("http://localhost:8123", "http://localhost:8123"),
("https://localhost:8123/", "https://localhost:8123/"),
("http://localhost:8123/bar", "http://localhost:8123/bar"),
("https://localhost:8123/bar?x=1", "https://localhost:8123/bar?x=1"),
("data:text/html,this is some data",
"data:text/html,this is some data"),
)
for url, result in expected:
with self.subTest(url=url):
self.assertEqual(ObjectParser.url_str(url), result)
self.assertEqual(ObjectParser.fuzzy_url_str(url), result)
parsed = ObjectParser.url(url)
self.assertEqual(urlparse.urlunparse(parsed), result)
parsed_fuzzy = ObjectParser.fuzzy_url(url)
self.assertEqual(urlparse.urlunparse(parsed_fuzzy), result)
def test_parse_url_invalid(self):
invalid: Any
for invalid in (None, "", {}, "http:// foo .com/bar", "htt p://foo.com",
"http://foo.com:-123/bar"):
with self.subTest(invalid=invalid):
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.url(invalid)
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.url_str(invalid)
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.httpx_url_str(invalid)
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.fuzzy_url_str(invalid)
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.fuzzy_url(invalid)
def test_parse_httpx_url_str_invalid(self):
invalid: Any
for invalid in ("ftp://foo.com:123/bar", "ssh://test.com"):
with self.subTest(invalid=invalid):
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.httpx_url_str(invalid)
def test_parse_url_scheme(self):
url = "ftp://foo.com"
parsed = ObjectParser.url(url)
self.assertEqual(urlparse.urlunparse(parsed), url)
with self.assertRaises(argparse.ArgumentTypeError):
_ = ObjectParser.url(url, schemes=("https",))
parsed = ObjectParser.url(
url, schemes=(
"https",
"ftp",
))
self.assertEqual(urlparse.urlunparse(parsed), url)
def test_parse_regexp(self):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.regexp("\\")
pattern = ObjectParser.regexp("^abc$")
self.assertEqual(pattern.pattern, "^abc$")
def test_bytes_or_file_contents_invalid(self):
with self.assertRaises(argparse.ArgumentTypeError):
ObjectParser.bytes_or_file_contents(None)
def test_bytes_or_file_contents(self):
result: bytes = ObjectParser.bytes_or_file_contents("some data")
self.assertEqual(result, b"some data")
file = pathlib.Path("file")
with file.open("w", encoding="utf-8") as f:
f.write("some data")
result = ObjectParser.bytes_or_file_contents(file)
self.assertEqual(result, b"some data")
def test_proto_or_file_textproto(self):
text_proto_file = pathlib.Path("trace_config.textproto")
text_config: str = """
buffers: {
size_kb: 123456
fill_policy: DISCARD
}
"""
parser = ObjectParser.proto_or_file(trace_config_pb2.TraceConfig)
proto_instance: trace_config_pb2.TraceConfig = parser(text_config)
self.assertEqual(len(proto_instance.buffers), 1)
self.assertEqual(proto_instance.buffers[0].size_kb, 123456)
with text_proto_file.open("w", encoding="utf-8") as f:
f.write(text_config)
proto_instance_2 = parser(text_proto_file)
self.assertEqual(len(proto_instance_2.buffers), 1)
self.assertEqual(proto_instance_2.buffers[0].size_kb, 123456)
self.assertEqual(proto_instance, proto_instance_2)
def test_proto_or_file_invalid(self):
text_proto_file = pathlib.Path("trace_config.textproto")
text_config: str = """
buffers-invalid: {
size_kb: 123456
fill_policy: DISCARD
}
"""
parser = ObjectParser.proto_or_file(trace_config_pb2.TraceConfig)
with self.assertRaisesRegex(argparse.ArgumentTypeError, "buffers-invalid"):
parser(text_config)
with text_proto_file.open("w", encoding="utf-8") as f:
f.write(text_config)
with self.assertRaisesRegex(argparse.ArgumentTypeError, "buffers-invalid"):
parser(text_proto_file)
def test_proto_or_file_binary(self):
text_proto_file = pathlib.Path("trace_config.textproto")
binary_proto_file = pathlib.Path("trace_config.proto")
text_config: str = """
buffers: {
size_kb: 123456
fill_policy: DISCARD
}
"""
with text_proto_file.open("w", encoding="utf-8") as f:
f.write(text_config)
parser = ObjectParser.proto_or_file(trace_config_pb2.TraceConfig)
proto_instance = parser(text_proto_file)
with binary_proto_file.open("wb") as f:
f.write(proto_instance.SerializeToString())
proto_instance_2 = parser(binary_proto_file)
self.assertEqual(len(proto_instance_2.buffers), 1)
self.assertEqual(proto_instance_2.buffers[0].size_kb, 123456)
self.assertEqual(proto_instance, proto_instance_2)
def test_proto_or_file_binary_invalid_format(self):
binary_proto_file = pathlib.Path("trace_config.proto")
self.fs.create_file(binary_proto_file, contents="invalid data")
parser = ObjectParser.proto_or_file(trace_config_pb2.TraceConfig)
with self.assertRaisesRegex(argparse.ArgumentTypeError, "TraceConfig"):
parser(binary_proto_file)
class TimeUnitTestCase(unittest.TestCase):
def test_parse_microseconds(self):
for unit in ("us", "micros", "microseconds"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.MICROSECOND)
def test_parse_milliseconds(self):
for unit in ("ms", "millis", "milliseconds"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.MILLISECOND)
def test_parse_seconds(self):
for unit in ("s", "sec", "secs", "second", "seconds"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.SECOND)
def test_parse_minutes(self):
for unit in ("m", "min", "mins", "minute", "minutes"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.MINUTE)
def test_parse_hours(self):
for unit in ("h", "hrs", "hour", "hours"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.HOUR)
def test_parse_days(self):
for unit in ("d", "day", "days"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.DAY)
def test_parse_weeks(self):
for unit in ("w", "week", "weeks"):
self.assertIs(TimeUnit.parse(unit), TimeUnit.WEEK)
def test_parse_invalid(self):
invalid: Any
for invalid in ("months", "yy", "ww", "i"):
with self.assertRaises(DurationParseError):
TimeUnit.parse(invalid)
def test_to_timedelta(self):
self.assertEqual(
TimeUnit.MICROSECOND.timedelta(123), dt.timedelta(microseconds=123))
self.assertEqual(
TimeUnit.MILLISECOND.timedelta(123), dt.timedelta(milliseconds=123))
self.assertEqual(TimeUnit.SECOND.timedelta(123), dt.timedelta(seconds=123))
self.assertEqual(TimeUnit.MINUTE.timedelta(123), dt.timedelta(minutes=123))
self.assertEqual(TimeUnit.HOUR.timedelta(123), dt.timedelta(hours=123))
self.assertEqual(TimeUnit.DAY.timedelta(123), dt.timedelta(days=123))
self.assertEqual(TimeUnit.WEEK.timedelta(123), dt.timedelta(weeks=123))
if __name__ == "__main__":
test_helper.run_pytest(__file__)