blob: c2ee8c616cfeb79a5123e61730c646c470c6e03b [file] [log] [blame] [edit]
# Copyright 2023 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import argparse
import contextlib
import datetime as dt
import enum
import json
import logging
import math
import re
import shlex
import sys
from typing import (Any, Dict, Final, Iterable, Iterator, List, Optional,
Sequence, Type, TypeVar, Union, cast)
from urllib import parse as urlparse
import colorama
import hjson
from crossbench import helper
from crossbench import path as pth
from crossbench import plt
def type_str(value: Any) -> str:
  """Return the bare class name of `value` (used in error messages)."""
  value_cls = type(value)
  return value_cls.__name__
def parse_path(value: pth.RemotePathLike, name: str = "value") -> pth.LocalPath:
  """Convert `value` into a user-expanded local path.

  Args:
    value: Path-like input; must be neither None nor empty.
    name: Label for `value` used in the expansion error message.

  Returns:
    `pth.LocalPath(value)` with `expanduser()` applied.

  Raises:
    argparse.ArgumentTypeError: If `value` is None or empty, or if
      `expanduser()` fails with a RuntimeError.
  """
  value = parse_not_none(value, "path")
  if not value:
    raise argparse.ArgumentTypeError("Invalid empty path.")
  try:
    return pth.LocalPath(value).expanduser()
  except RuntimeError as e:
    # The message used to carry an unbalanced "'" after the repr.
    raise argparse.ArgumentTypeError(
        f"Invalid Path {name} {repr(value)}: {e}") from e
def parse_existing_file_path(value: pth.RemotePathLike,
                             name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path, requiring it to be an existing file.

  Raises:
    argparse.ArgumentTypeError: If the path does not exist or is not a file.
  """
  existing = parse_existing_path(value, name)
  if existing.is_file():
    return existing
  raise argparse.ArgumentTypeError(
      f"{name} is not a file: {repr(str(existing))}")
def parse_non_empty_file_path(value: pth.RemotePathLike,
                              name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path to an existing file with size > 0.

  Raises:
    argparse.ArgumentTypeError: If the path is not an existing file or the
      file has a size of zero bytes.
  """
  file_path: pth.LocalPath = parse_existing_file_path(value, name)
  size = file_path.stat().st_size
  if size != 0:
    return file_path
  raise argparse.ArgumentTypeError(
      f"{name} is an empty file: {repr(str(file_path))}")
def parse_file_path(value: pth.RemotePathLike,
                    name: str = "value") -> pth.LocalPath:
  """Alias for `parse_non_empty_file_path`."""
  non_empty_file = parse_non_empty_file_path(value, name)
  return non_empty_file
def parse_dir_path(value: pth.RemotePathLike,
                   name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path, requiring it to be an existing folder.

  Args:
    value: Path-like input.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If the path does not exist or is not a
      directory.
  """
  path = parse_existing_path(value, name)
  if not path.is_dir():
    # repr() already quotes the path; the extra literal quotes produced
    # doubled quotes, unlike the sibling path parsers.
    raise argparse.ArgumentTypeError(
        f"{name} is not a folder: {repr(str(path))}")
  return path
def parse_non_empty_dir_path(value: pth.RemotePathLike,
                             name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path to a folder with at least one entry.

  Raises:
    argparse.ArgumentTypeError: If the path is not an existing folder or the
      folder has no entries.
  """
  dir_path = parse_dir_path(value, name)
  # Stops at the first entry; the directory is never fully listed.
  has_entry = any(True for _ in dir_path.iterdir())
  if has_entry:
    return dir_path
  raise argparse.ArgumentTypeError(
      f"{name} dir must be non empty: {repr(str(dir_path))}")
def parse_existing_path(value: pth.RemotePathLike,
                        name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path, requiring that it exists.

  Args:
    value: Path-like input.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a valid path or the path
      does not exist.
  """
  # Forward `name` so that errors raised by parse_path use the caller's label.
  path = parse_path(value, name)
  if not path.exists():
    raise argparse.ArgumentTypeError(
        f"{name} path does not exist: {repr(str(path))}")
  return path
def parse_not_existing_path(value: pth.RemotePathLike,
                            name: str = "value") -> pth.LocalPath:
  """Return `value` as a local path, requiring that nothing exists there yet.

  Args:
    value: Path-like input.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a valid path or the path
      already exists.
  """
  # Forward `name` so that errors raised by parse_path use the caller's label.
  path = parse_path(value, name)
  if path.exists():
    raise argparse.ArgumentTypeError(
        f"{name} path already exists: {repr(str(path))}")
  return path
def parse_binary_path(
    value: Optional[pth.RemotePathLike],
    name: str = "binary",
    platform: Optional[plt.Platform] = None) -> pth.RemotePath:
  """Resolve `value` to a binary on `platform`.

  If `value` is not a file on the platform, it is looked up with
  `platform.search_binary()`. `plt.PLATFORM` is used when no platform is
  given.

  Raises:
    argparse.ArgumentTypeError: If `value` is None or no binary is found.
  """
  active_platform = platform or plt.PLATFORM
  candidate = active_platform.path(parse_not_none(value, name))
  if active_platform.is_file(candidate):
    return candidate
  found = active_platform.search_binary(candidate)
  if found:
    return found
  raise argparse.ArgumentTypeError(f"Unknown binary: {value}")
def parse_remote_path(value: Optional[pth.RemotePathLike],
                      name: str = "value") -> pth.RemotePath:
  """Wrap a non-None, non-empty `value` in a `pth.RemotePath`.

  Raises:
    argparse.ArgumentTypeError: If `value` is None or empty.
  """
  path_like: pth.RemotePathLike = parse_not_none(value, name)
  if path_like:
    return pth.RemotePath(path_like)
  raise argparse.ArgumentTypeError(f"Expected non empty path {name}.")
def parse_local_binary_path(
    value: Optional[pth.RemotePathLike],
    name: str = "binary") -> pth.LocalPath:
  """Run `parse_binary_path` on the default platform.

  The result is only re-typed as `pth.LocalPath` for type checkers; no
  runtime conversion takes place.
  """
  binary = parse_binary_path(value, name)
  return cast(pth.LocalPath, binary)
EnumT = TypeVar("EnumT", bound=enum.Enum)


def parse_enum(label: str, enum_cls: Type[EnumT], data: Any,
               choices: Union[Type[EnumT], Iterable[EnumT]]) -> EnumT:
  """Convert `data` into a member of `enum_cls`.

  Args:
    label: Human-readable name of the enum, used in the error message.
    enum_cls: Enum class to convert to.
    data: Either an enum member or a raw value.
    choices: Allowed members; either the enum class itself or any iterable
      of members (one-shot iterators are supported).

  Returns:
    The matching enum member.

  Raises:
    argparse.ArgumentTypeError: If `data` matches none of the choices.
  """
  try:
    # Try direct conversion, relying on the Enum._missing_ hook:
    enum_value = enum_cls(data)
    assert isinstance(enum_value, enum.Enum)
    assert isinstance(enum_value, enum_cls)
    return enum_value
  except Exception as e:  # pylint: disable=broad-except
    logging.debug("Could not auto-convert data '%s' to enum %s: %s", data,
                  enum_cls, e)

  # `choices` is walked twice below (lookup + error message). Materialize it
  # so that a generator is not already exhausted when the message is built.
  all_choices = tuple(choices)  # pytype: disable=missing-parameter
  for enum_instance in all_choices:
    if data in (enum_instance, enum_instance.value):
      return enum_instance
  choices_str: str = ", ".join(repr(item.value) for item in all_choices)
  raise argparse.ArgumentTypeError(f"Unknown {label}: {repr(data)}.\n"
                                   f"Choices are {choices_str}.")
def parse_inline_hjson(value: Any) -> Any:
  """Decode an inline hjson object literal.

  The input has to be a non-empty string that starts with "{" and ends
  with "}".

  Raises:
    argparse.ArgumentTypeError: If the braces are missing or hjson fails to
      decode the string.
  """
  text = parse_non_empty_str(value, hjson.__name__)
  has_braces = text[0] == "{" and text[-1] == "}"
  if not has_braces:
    raise argparse.ArgumentTypeError(
        f"Invalid inline {hjson.__name__}, missing braces: '{text}'")
  try:
    decoded = hjson.loads(text)
  except ValueError as e:
    message = _extract_decoding_error(
        f"Could not decode inline {hjson.__name__}", text, e)
    if "eof" in message:
      message += "\n Likely missing quotes."
    raise argparse.ArgumentTypeError(message) from e
  return decoded
_MAX_LEN = 70


def _extract_decoding_error(message: str, value: pth.RemotePathLike,
                            e: ValueError) -> str:
  """Format a decoding error with the offending line and a position marker.

  Args:
    message: Prefix describing what failed to decode.
    value: Either the decoded text itself (str) or the path of the decoded
      file.
    e: The decoder error. Its 1-based `lineno` / `colno` attributes are used
      when present (json.JSONDecodeError provides them).

  Returns:
    A multi-line message. When the error position is known it contains the
    offending line (cropped to about `_MAX_LEN` characters around the error)
    and a marker line pointing at the error column.
  """
  lineno = getattr(e, "lineno", -1) - 1
  colno = getattr(e, "colno", -1) - 1
  if lineno < 0 or colno < 0:
    # No position information: fall back to the plain decoder message.
    if isinstance(value, pth.LocalPath):
      return f"{message}\n {str(e)}"
    return f"{message}: {value}\n {str(e)}"
  if isinstance(value, str):
    lines = value.splitlines()
  else:
    with pth.LocalPath(value).open(encoding="utf-8") as f:
      lines = f.readlines()
  if lineno < len(lines):
    # Lines read from a file keep their terminator, which would otherwise
    # insert a blank line between the source line and the marker.
    line = lines[lineno].rstrip("\r\n")
  else:
    # Decoders can report a position past the last line, e.g. when the input
    # ends right after a newline ("{\n" fails at line 2, column 1).
    line = ""
  if len(line) > _MAX_LEN:
    # Only show line around error:
    start = colno - _MAX_LEN // 2
    end = colno + _MAX_LEN // 2
    prefix = "..."
    suffix = "..."
    if start < 0:
      start = 0
      end = _MAX_LEN
      prefix = ""
    elif end > len(line):
      end = len(line)
      start = len(line) - _MAX_LEN
      suffix = ""
    colno -= start
    line = prefix + line[start:end] + suffix
    marker_space = (" " * len(prefix)) + (" " * colno)
  else:
    marker_space = " " * colno
  # Three characters wide so that the triangle sits exactly one column in.
  marker = "_▲_"
  # Adjust line to be aligned with marker size
  line = (" " * (len(marker) // 2)) + line
  return f"{message}\n {line}\n {marker_space}{marker}\n({str(e)})"
def parse_json_file_path(value: pth.RemotePathLike) -> pth.LocalPath:
  """Return the path of a non-empty file after checking it decodes as JSON.

  The decoded data is discarded; only the path is returned.

  Raises:
    argparse.ArgumentTypeError: If the file is missing, empty or not valid
      JSON.
  """
  json_path = parse_file_path(value)
  with json_path.open(encoding="utf-8") as json_file:
    try:
      json.load(json_file)
    except ValueError as e:
      details = _extract_decoding_error(f"Invalid json file '{json_path}':",
                                        json_path, e)
      raise argparse.ArgumentTypeError(details) from e
  return json_path
def parse_hjson_file_path(value: pth.RemotePathLike) -> pth.LocalPath:
  """Return the path of a non-empty file after checking it decodes as hjson.

  The decoded data is discarded; only the path is returned.

  Raises:
    argparse.ArgumentTypeError: If the file is missing, empty or cannot be
      decoded by hjson.
  """
  hjson_path = parse_file_path(value)
  with hjson_path.open(encoding="utf-8") as hjson_file:
    try:
      hjson.load(hjson_file)
    except ValueError as e:
      details = _extract_decoding_error(
          f"Invalid {hjson.__name__} file '{hjson_path}':", hjson_path, e)
      raise argparse.ArgumentTypeError(details) from e
  return hjson_path
def parse_json_file(value: pth.RemotePathLike) -> Any:
  """Load and return the JSON data stored in a non-empty file.

  Raises:
    argparse.ArgumentTypeError: If the file is missing, empty or not valid
      JSON.
  """
  json_path = parse_file_path(value)
  with json_path.open(encoding="utf-8") as json_file:
    try:
      data = json.load(json_file)
    except ValueError as e:
      details = _extract_decoding_error(f"Invalid json file '{json_path}':",
                                        json_path, e)
      raise argparse.ArgumentTypeError(details) from e
    return data
def parse_hjson_file(value: pth.RemotePathLike) -> Any:
  """Load and return the hjson data stored in a non-empty file.

  Raises:
    argparse.ArgumentTypeError: If the file is missing, empty or cannot be
      decoded by hjson.
  """
  hjson_path = parse_file_path(value)
  with hjson_path.open(encoding="utf-8") as hjson_file:
    try:
      data = hjson.load(hjson_file)
    except ValueError as e:
      details = _extract_decoding_error(
          f"Invalid {hjson.__name__} file '{hjson_path}':", hjson_path, e)
      raise argparse.ArgumentTypeError(details) from e
    return data
def parse_non_empty_hjson_file(value: pth.RemotePathLike) -> Any:
  """Load an hjson file and require the decoded data to be truthy.

  Raises:
    argparse.ArgumentTypeError: If loading fails or the data is falsy
      (for instance an empty object or list).
  """
  loaded = parse_hjson_file(value)
  if loaded:
    return loaded
  raise argparse.ArgumentTypeError(
      f"Expected {hjson.__name__} file with non-empty data, "
      f"but got: {hjson.dumps(loaded)}")
def parse_dict_hjson_file(value: pth.RemotePathLike) -> Any:
  """Load an hjson file whose top-level value must be a non-empty dict.

  Raises:
    argparse.ArgumentTypeError: If loading fails, the data is empty or the
      top-level value is not a dict.
  """
  loaded = parse_non_empty_hjson_file(value)
  if isinstance(loaded, dict):
    return loaded
  raise argparse.ArgumentTypeError(
      f"Expected object in {hjson.__name__} config '{value}', "
      f"but got {type_str(loaded)}: {repr(loaded)}")
def parse_dict(value: Any, name: str = "value") -> Dict:
  """Return `value` unchanged if it is a dict.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a dict.
  """
  if not isinstance(value, dict):
    raise argparse.ArgumentTypeError(
        f"Expected dict, but {name} is {type_str(value)}: {repr(value)}")
  return value
def parse_non_empty_dict(value: Any, name: str = "value") -> Dict:
  """Return `value` if it is a dict with at least one entry.

  Args:
    value: Candidate object.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a dict or is empty.
  """
  # Forward `name` so the type error is reported with the caller's label too.
  dict_value = parse_dict(value, name)
  if not dict_value:
    raise argparse.ArgumentTypeError(f"Expected {name} to be a non-empty dict.")
  return dict_value
def parse_sequence(value: Any, name: str = "value") -> Sequence[Any]:
  """Return `value` unchanged if it is a list or a tuple.

  Raises:
    argparse.ArgumentTypeError: For any other type (including str).
  """
  if not isinstance(value, (list, tuple)):
    raise argparse.ArgumentTypeError(
        f"Expected sequence, but {name} is {type_str(value)}: {repr(value)}")
  return value
def parse_non_empty_sequence(value: Any, name: str = "value") -> Sequence[Any]:
  """Return `value` if it is a list or tuple with at least one item.

  Args:
    value: Candidate object.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a list/tuple or is empty.
  """
  # Forward `name` so the type error is reported with the caller's label too.
  sequence_value = parse_sequence(value, name)
  if not sequence_value:
    raise argparse.ArgumentTypeError(
        f"Expected {name} to be a non-empty sequence.")
  return sequence_value
def try_resolve_existing_path(value: str) -> Optional[pth.LocalPath]:
  """Return an existing local path for `value`, or None.

  The path is tried as given first and, only if that does not exist, again
  after `expanduser()`. An empty `value` yields None.
  """
  if not value:
    return None
  candidate = pth.LocalPath(value)
  if not candidate.exists():
    candidate = candidate.expanduser()
    if not candidate.exists():
      return None
  return candidate
def parse_float(value: Any, name: str = "float") -> float:
  """Convert `value` to a float.

  Args:
    value: Anything accepted by `float()`.
    name: Label for `value` used in the error message.

  Raises:
    argparse.ArgumentTypeError: If `float(value)` fails.
  """
  try:
    return float(value)
  # float() raises TypeError for unsupported types (None, list, ...) and
  # OverflowError for ints that do not fit into a float; report those like
  # any other invalid input instead of leaking the raw exception.
  except (TypeError, ValueError, OverflowError) as e:
    raise argparse.ArgumentTypeError(f"Invalid {name}: {repr(value)}") from e
def parse_positive_zero_float(value: Any, name: str = "float") -> float:
  """Convert `value` to a finite float that is >= 0.

  Raises:
    argparse.ArgumentTypeError: If the conversion fails or the result is
      negative, NaN or infinite.
  """
  number = parse_float(value, name)
  if math.isfinite(number) and number >= 0:
    return number
  raise argparse.ArgumentTypeError(
      f"Expected {name} >= 0, but got: {number}")
def parse_int(value: Any, name: str = "value") -> int:
  """Convert `value` to an int.

  Args:
    value: Anything accepted by `int()`.
    name: Label for `value` used in the error message.

  Raises:
    argparse.ArgumentTypeError: If `int(value)` fails.
  """
  try:
    return int(value)
  # int() raises TypeError for unsupported types (None, list, ...) and
  # OverflowError for infinite floats; report those like any other invalid
  # input instead of leaking the raw exception.
  except (TypeError, ValueError, OverflowError) as e:
    raise argparse.ArgumentTypeError(
        f"Invalid integer {name}: {repr(value)}") from e
def parse_positive_zero_int(value: Any, name: str = "value") -> int:
  """Convert `value` to an int that is >= 0.

  Raises:
    argparse.ArgumentTypeError: If the conversion fails or the result is
      negative.
  """
  number = parse_int(value, name)
  if number >= 0:
    return number
  raise argparse.ArgumentTypeError(
      f"Expected integer {name} >= 0, but got: {number}")
def parse_positive_int(value: Any, name: str = "value") -> int:
  """Convert `value` to an int that is > 0.

  Args:
    value: Anything accepted by `parse_int`.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If the conversion fails or the result is
      zero or negative.
  """
  value_i = parse_int(value, name)
  # No finiteness check here: ints are always finite, and math.isfinite()
  # raises OverflowError for ints that are too large to convert to a float.
  if value_i <= 0:
    raise argparse.ArgumentTypeError(
        f"Expected integer {name} > 0, but got: {value_i}")
  return value_i
def parse_port(value: Any, name: str = "port") -> int:
  """Convert `value` to an int within the range 1..65535.

  Raises:
    argparse.ArgumentTypeError: If the conversion fails or the number is
      outside of that range.
  """
  port = parse_int(value, name)
  if port < 1 or port > 65535:
    raise argparse.ArgumentTypeError(
        f"Invalid Port: expected 1 <= {name} <= 65535, but got: {repr(port)}")
  return port
def parse_str(value: Any, name: str = "value") -> str:
  """Return `value` unchanged if it is a str.

  Raises:
    argparse.ArgumentTypeError: If `value` is None or not a str.
  """
  some_value = parse_not_none(value, name)
  if not isinstance(some_value, str):
    raise argparse.ArgumentTypeError(
        f"Expected str, but got {type_str(some_value)}: {some_value}")
  return some_value
def parse_non_empty_str(value: Any, name: str = "value") -> str:
  """Return `value` if it is a str with at least one character.

  Args:
    value: Candidate object.
    name: Label for `value` used in error messages.

  Raises:
    argparse.ArgumentTypeError: If `value` is None, not a str or empty.
  """
  # parse_str() either returns a str or raises, so a second isinstance check
  # could never fail; only emptiness is left to verify.
  str_value = parse_str(value, name)
  if not str_value:
    raise argparse.ArgumentTypeError(f"Non-empty string {name} expected.")
  return str_value
def parse_url_str(value: str,
                  name: str = "url",
                  schemes: Optional[Sequence[str]] = None) -> str:
  """Validate `value` with `parse_url` and return the original string."""
  # Only called for validation; the parsed result is not needed.
  _ = parse_url(value, name, schemes)
  return value
def parse_httpx_url_str(value: Any, name: str = "url") -> str:
  """Validate `value` as an http or https URL and return it unchanged."""
  allowed_schemes = ("http", "https")
  # Only called for validation; the parsed result is not needed.
  _ = parse_url(value, name, schemes=allowed_schemes)
  return value
def parse_base_url(value: str, name: str = "url") -> urlparse.ParseResult:
  """Split a non-empty string with `urllib.parse.urlparse`.

  No scheme or host validation is done here.

  Raises:
    argparse.ArgumentTypeError: If `value` is not a non-empty str or
      `urlparse` raises a ValueError.
  """
  url_str: str = parse_non_empty_str(value, name)
  try:
    result = urlparse.urlparse(url_str)
  except ValueError as e:
    raise argparse.ArgumentTypeError(
        f"Invalid {name}: {repr(value)}, {e}") from e
  return result
# Matches the start of values that look like file-system paths: an optional
# ".", ".." or "~" prefix, or a single-letter drive prefix such as "C:",
# followed by one "/" or "\" separator and one non-separator character.
PATH_PREFIX = re.compile(r"^(?:"
                         r"(?:\.\.?|~)?|"
                         r"[a-zA-Z]:"
                         r")(\\|/)[^\\/]")
# Matches a string starting with digits that are followed by "/" or the end
# of the string. parse_fuzzy_url applies it to the parsed path to detect
# "host:port/path" inputs whose host was mistaken for the scheme.
PORT_URL_PATH_RE = re.compile(r"^[0-9]+(?:/|$)")
def parse_fuzzy_url_str(value: str,
                        name: str = "url",
                        schemes: Sequence[str] = ("http", "https", "about",
                                                  "file"),
                        default_scheme: str = "https") -> str:
  """Like `parse_fuzzy_url`, but return the normalized URL as a string."""
  fuzzy_result = parse_fuzzy_url(value, name, schemes, default_scheme)
  url_str = urlparse.urlunparse(fuzzy_result)
  return url_str
def parse_fuzzy_url(value: str,
                    name: str = "url",
                    schemes: Sequence[str] = ("http", "https", "about", "file"),
                    default_scheme: str = "https") -> urlparse.ParseResult:
  """Parse a URL, filling in a missing scheme where possible.

  Values matching `PATH_PREFIX` are turned into "file://" URLs. Otherwise
  `default_scheme` is prepended when the value has no scheme, or when it
  looks like "host:port/path" without a scheme.

  Args:
    value: The user-provided URL-like string.
    name: Label for `value` used in error messages.
    schemes: Accepted URL schemes.
    default_scheme: Scheme to prepend when none is present; always accepted
      for non-path inputs.

  Returns:
    The result of `parse_url` on the completed URL.

  Raises:
    argparse.ArgumentTypeError: If the completed URL is invalid.
  """
  assert default_scheme, "missing default scheme value"
  value = parse_non_empty_str(value, name)
  if PATH_PREFIX.match(value):
    value = f"file://{value}"
  else:
    parsed = parse_base_url(value, name)
    # Check if this was a url without a scheme but with ports, which gets
    # "wrongly" parsed and the host ends up in result.scheme and port and path
    # are merged into result.path:
    #   foo.com:8080/test => https://foo.com:8080/test
    looks_like_host_and_port = (
        parsed.scheme not in schemes and not parsed.netloc and
        PORT_URL_PATH_RE.match(parsed.path))
    # Both conditions can hold at once (e.g. "8080/test" has no scheme and a
    # digits-only path). Prepend the default scheme exactly once in that case
    # rather than producing "https://https://8080/test".
    if not parsed.scheme or looks_like_host_and_port:
      value = f"{default_scheme}://{value}"
    schemes = tuple(schemes) + (default_scheme,)
  return parse_url(value, name, schemes)
def parse_url(value: str,
              name: str = "url",
              schemes: Optional[Sequence[str]] = None) -> urlparse.ParseResult:
  """Parse and validate a URL.

  Args:
    value: The URL string.
    name: Label for `value` used in error messages.
    schemes: Accepted schemes; any scheme is accepted when empty or None.

  Returns:
    The `urllib.parse` result for `value`.

  Raises:
    argparse.ArgumentTypeError: If the scheme is not accepted, the port is
      invalid, or a URL other than "file"/"about" has a missing hostname or
      a hostname containing a space.
  """
  parsed = parse_base_url(value, name)
  try:
    scheme = parsed.scheme
    if schemes and scheme not in schemes:
      schemes_str = ",".join(map(repr, schemes))
      raise argparse.ArgumentTypeError(
          f"Invalid {name}: Expected scheme to be one of {schemes_str}, "
          f"but got {repr(parsed.scheme)} for url {repr(value)}")
    # Compare against None: a truthiness test would skip the validation for
    # an explicit port 0, which parse_port rejects.
    port = parsed.port
    if port is not None:
      parse_port(port, f"{name} port")
    if scheme in ("file", "about"):
      # These schemes need no hostname.
      return parsed
    hostname = parsed.hostname
    if not hostname:
      raise argparse.ArgumentTypeError(
          f"Missing hostname in {name}: {repr(value)}")
    if " " in hostname:
      raise argparse.ArgumentTypeError(
          f"Hostname in {name} contains invalid space: {repr(value)}")
  except ValueError as e:
    # Some ParseResult properties trigger errors, wrap all of them
    raise argparse.ArgumentTypeError(
        f"Invalid {name}: {repr(value)}, {e}") from e
  return parsed
def parse_bool(value: Any, name: str = "value") -> bool:
  """Convert a bool or a "true"/"false" string (any casing) to a bool.

  Args:
    value: A bool, or an object whose `str()` is "true" or "false" ignoring
      case.
    name: Label for `value` used in the error message.

  Raises:
    argparse.ArgumentTypeError: For every other input.
  """
  if isinstance(value, bool):
    return value
  # Keep `value` untouched so that the error below reports the caller's
  # original object and type rather than the lower-cased string.
  lowered = str(value).lower()
  if lowered == "true":
    return True
  if lowered == "false":
    return False
  raise argparse.ArgumentTypeError(
      f"Expected bool {name} but got {type_str(value)}: {repr(value)}")
NotNoneT = TypeVar("NotNoneT")


def parse_not_none(value: Optional[NotNoneT], name: str = "value") -> NotNoneT:
  """Return `value` unchanged unless it is None.

  Falsy values other than None (0, "", empty containers) are accepted.

  Raises:
    argparse.ArgumentTypeError: If `value` is None.
  """
  if value is not None:
    return value
  raise argparse.ArgumentTypeError(f"Expected {name} to be not None.")
def parse_sh_cmd(value: Any) -> List[str]:
  """Convert a shell command into a list of argument strings.

  A str is split with `shlex.split`; a list or tuple has to consist of
  non-empty strings and is returned as a new list.

  Raises:
    argparse.ArgumentTypeError: If `value` is None, empty, of another type,
      contains invalid parts, or cannot be split by shlex.
  """
  cmd = parse_not_none(value, "shell cmd")
  if not cmd:
    raise argparse.ArgumentTypeError(
        f"Expected non-empty shell cmd, but got: {cmd}")
  if isinstance(cmd, str):
    try:
      return shlex.split(cmd)
    except ValueError as e:
      raise argparse.ArgumentTypeError(f"Invalid shell cmd: {cmd} ") from e
  if not isinstance(cmd, (list, tuple)):
    raise argparse.ArgumentTypeError(
        f"Expected string or list, but got {type_str(cmd)}: {cmd}")
  for index, part in enumerate(cmd):
    parse_non_empty_str(part, f"cmd[{index}]")
  return list(cmd)
SequenceT = TypeVar("SequenceT", bound=Sequence)


def parse_unique_sequence(
    value: SequenceT,
    name: str = "sequence",
    error_cls: Type[Exception] = argparse.ArgumentTypeError) -> SequenceT:
  """Return `value` unchanged if it holds no repeated items.

  Items are collected in sets and thus need to be hashable.

  Raises:
    error_cls: If at least one item occurs more than once; the message
      lists the repeated items.
  """
  seen = set()
  repeated = set()
  for element in value:
    bucket = repeated if element in seen else seen
    bucket.add(element)
  if repeated:
    raise error_cls(f"Unexpected duplicates in {name}: {repr(repeated)}")
  return value
class LateArgumentError(argparse.ArgumentTypeError):
  """Argument error that is raised after parser.parse_args() has finished.

  It carries the flag the error belongs to, so that the error can be mapped
  back to the original argument, similar to what argparse.ArgumentError does
  internally. A custom class is needed since these errors occur after the
  internal argument parsing and should still produce descriptive messages.
  """

  def __init__(self, flag: str, message: str) -> None:
    self.flag = flag
    self.message = message
    super().__init__(message)
@contextlib.contextmanager
def late_argument_type_error_wrapper(flag: str) -> Iterator[None]:
  """Re-raise any Exception from the with-body as a LateArgumentError.

  The new error is associated with the given flag, uses str() of the
  original exception as its message and is chained to the original.
  """
  try:
    yield
  except Exception as error:
    raise LateArgumentError(flag, str(error)) from error
class DurationParseError(argparse.ArgumentTypeError):
  """Raised when a value cannot be interpreted as a duration."""


class Duration:
  """Parsers that convert user-provided values into dt.timedelta objects."""

  @classmethod
  def help(cls) -> str:
    """Return a short hint describing the accepted duration syntax."""
    return "'12.5' == '12.5s', units=['ms', 's', 'm', 'h']"

  # A (possibly negative) decimal number, an optional single space and an
  # optional lower-case unit suffix.
  _DURATION_RE: Final[re.Pattern] = re.compile(
      r"(?P<value>(-?\d+(\.\d+)?)) ?(?P<unit>[a-z]+)?")

  @classmethod
  def _to_timedelta(cls, value: float, suffix: str) -> dt.timedelta:
    """Build a timedelta from `value` in the unit named by `suffix`.

    Raises:
      DurationParseError: If `suffix` is not a known unit.
    """
    if suffix in {"ms", "millis", "milliseconds"}:
      return dt.timedelta(milliseconds=value)
    if suffix in {"s", "sec", "secs", "second", "seconds"}:
      return dt.timedelta(seconds=value)
    if suffix in {"m", "min", "mins", "minute", "minutes"}:
      return dt.timedelta(minutes=value)
    if suffix in {"h", "hrs", "hour", "hours"}:
      return dt.timedelta(hours=value)
    raise DurationParseError(f"Error: {suffix} is not supported for duration. "
                             "Make sure to use a supported time unit/suffix")

  @classmethod
  def parse(cls, time_value: Any, name: str = "duration") -> dt.timedelta:
    """Alias for `parse_non_zero`."""
    return cls.parse_non_zero(time_value, name)

  @classmethod
  def parse_non_zero(cls,
                     time_value: Any,
                     name: str = "duration") -> dt.timedelta:
    """Parse a duration that has to be strictly greater than zero.

    Raises:
      DurationParseError: If parsing fails or the duration is <= 0.
    """
    # Forward `name` so that parse errors use the caller's label as well.
    duration: dt.timedelta = cls.parse_any(time_value, name)
    if duration.total_seconds() <= 0:
      raise DurationParseError(f"Expected non-zero {name}, but got {duration}")
    return duration

  @classmethod
  def parse_zero(cls, time_value: Any, name: str = "duration") -> dt.timedelta:
    """Parse a duration that has to be greater than or equal to zero.

    Raises:
      DurationParseError: If parsing fails or the duration is negative.
    """
    duration: dt.timedelta = cls.parse_any(time_value, name)
    if duration.total_seconds() < 0:
      raise DurationParseError(f"Expected positive {name}, but got {duration}")
    return duration

  @classmethod
  def parse_any(cls, time_value: Any, name: str = "duration") -> dt.timedelta:
    """Parse a duration without any restriction on its sign.

    Accepted inputs:
      - dt.timedelta: returned unchanged.
      - int / float: interpreted as seconds.
      - str: a number with an optional unit suffix, for example:
          5s => dt.timedelta(seconds=5)
          5m => dt.timedelta(minutes=5)
        A missing unit means seconds.

    Raises:
      DurationParseError: For empty, non-finite or malformed values and for
        unknown units.
    """
    if isinstance(time_value, dt.timedelta):
      return time_value
    if isinstance(time_value, (int, float)):
      # timedelta() fails with a bare ValueError/OverflowError for NaN and
      # infinity; report them the same way as non-finite string input.
      if isinstance(time_value, float) and not math.isfinite(time_value):
        raise DurationParseError(
            f"{name} must be finite, but got: {time_value}")
      return dt.timedelta(seconds=time_value)
    if not time_value:
      raise DurationParseError(f"Expected non-empty {name} value.")
    if not isinstance(time_value, str):
      raise DurationParseError(
          f"Unexpected {type_str(time_value)} for {name}: {time_value}")

    match = cls._DURATION_RE.fullmatch(time_value)
    if match is None:
      raise DurationParseError(f"Unknown {name} format: '{time_value}'")

    value = match.group("value")
    if not value:
      raise DurationParseError(
          f"Error: {name} value not found. "
          f"Make sure to include a valid {name} value: '{time_value}'")
    time_unit = match.group("unit")
    try:
      time_value = float(value)
    except ValueError as e:
      raise DurationParseError(f"{name} must be a valid number, {e}") from e
    # Very long digit strings convert to infinity.
    if not math.isfinite(time_value):
      raise DurationParseError(f"{name} must be finite, but got: {time_value}")

    if not time_unit:
      # If no time unit provided we assume it is in seconds.
      return dt.timedelta(seconds=time_value)
    return cls._to_timedelta(time_value, time_unit)
@contextlib.contextmanager
def timer(msg: str = "Elapsed Time"):
  """Context manager that keeps printing the elapsed time to stdout.

  While the with-body runs, a `helper.RepeatTimer` (interval=0.25,
  presumably seconds) calls a function that writes `msg` and the time passed
  since entering the context, terminated by a carriage return.
  """
  started_at = dt.datetime.now()

  def print_timer():
    elapsed = dt.datetime.now() - started_at
    indent = colorama.Cursor.FORWARD() * 3
    sys.stdout.write(f"{indent}{msg}: {elapsed}\r")

  with helper.RepeatTimer(interval=0.25, function=print_timer):
    yield