blob: c5349abca5a0ebf44a674d3aeb40ba86de004ba4 [file] [log] [blame]
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
from __future__ import annotations
import difflib
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, \
Protocol, TypeVar
if TYPE_CHECKING:
from crossbench.path import AnyPath
# Generic type parameters shared by the helpers in this module.
InputT = TypeVar("InputT")
KeyT = TypeVar("KeyT")
ValueT = TypeVar("ValueT")
# AnyPath is only imported under TYPE_CHECKING, so the bound has to be a
# string forward reference: a bare `AnyPath` would be evaluated at import time
# and raise a NameError.
PathT = TypeVar("PathT", bound="AnyPath")
class GroupTProtocol(Protocol):
  """Structural type for group containers: anything with an append() method."""

  def append(self, item: Any, /) -> None:
    """Add a single item to the group."""
    ...


# Type parameter for group objects; restricted to types that match
# GroupTProtocol.
GroupT = TypeVar("GroupT", bound=GroupTProtocol)
def default_value(value: InputT) -> InputT:
  """Identity function: return the given value unchanged.

  Used as the fallback value mapper when a caller does not supply one.
  """
  unchanged = value
  return unchanged
def group_by(
    collection: Iterable[InputT],
    *,
    key: Callable[[InputT], KeyT],
    value: Callable[[InputT], ValueT] | None = None,
    sort_key: Optional[Callable[[tuple[KeyT, Any]], Any]] = str
) -> dict[KeyT, list[ValueT]]:
  """Group the items of a collection into lists, keyed by key(item).

  Comparable to itertools.groupby, except that the grouping is global
  (SQL-style) over the whole collection instead of over consecutive runs
  of equal keys like uniq.

  Args:
    collection: the items to group.
    key: returns the grouping key for an item.
    value: maps each item to the value stored in its group; when omitted
      the item itself is stored.
    sort_key: forwarded to sorted(..., key=sort_key) over the
      (group_key, group) pairs to order the resulting dict. Pass None to
      skip sorting.

  Returns:
    A dict mapping each group key to the list of values for that key.
  """
  value_mapper = value or default_value

  def make_list_group(_group_key):
    # Every key starts out with its own fresh, empty list.
    return []

  return _group_by(
      collection,
      key_fn=key,
      group_fn=make_list_group,
      value_fn=value_mapper,
      sort_key_fn=sort_key)
def group_by_custom(
    collection: Iterable[InputT],
    *,
    key: Callable[[InputT], KeyT],
    group: Callable[[KeyT], GroupT],
    value: Optional[Callable[[InputT], ValueT]] = None,
    sort_key: Optional[Callable[[tuple[KeyT, Any]], Any]] = str
) -> dict[KeyT, GroupT]:
  """Group the items of a collection into caller-provided group objects.

  Comparable to itertools.groupby, except that the grouping is global
  (SQL-style) over the whole collection instead of over consecutive runs
  of equal keys like uniq.

  This variant only exists to keep the type system happy when the groups
  are not plain lists.

  Args:
    collection: the items to group.
    key: returns the grouping key for an item.
    group: called with a group key; returns a new group object that has an
      append() method.
    value: maps each item to the value stored in its group; when omitted
      the item itself is stored.
    sort_key: forwarded to sorted(..., key=sort_key) over the
      (group_key, group) pairs to order the resulting dict. Pass None to
      skip sorting.

  Returns:
    A dict mapping each group key to its group object.
  """
  value_mapper = value or default_value
  grouped = _group_by(
      collection,
      key_fn=key,
      group_fn=group,
      value_fn=value_mapper,
      sort_key_fn=sort_key)
  return grouped
def _group_by(
    collection: Iterable[InputT],
    key_fn: Callable[[InputT], KeyT],
    group_fn: Callable[[KeyT], GroupT],
    value_fn: Callable[[InputT], ValueT],
    sort_key_fn: Optional[Callable[[tuple[KeyT, Any]], Any]] = str
) -> dict[KeyT, GroupT]:
  """Shared implementation behind group_by() and group_by_custom().

  Args:
    collection: the items to group.
    key_fn: returns the grouping key for an item.
    group_fn: called once per distinct key to create the group object; the
      result must provide an append() method.
    value_fn: maps an item to the value appended to its group.
    sort_key_fn: if truthy, the (group_key, group) pairs are ordered with
      sorted(..., key=sort_key_fn) before the result dict is built;
      otherwise groups stay in first-seen order.

  Returns:
    A dict mapping each group key to its group object.

  Raises:
    ValueError: if no key function is provided.
  """
  if not key_fn:  # type: ignore
    raise ValueError("No key function provided")
  groups: dict[KeyT, GroupT] = {}
  for input_item in collection:
    group_key: KeyT = key_fn(input_item)
    group_item: ValueT = value_fn(input_item)
    # Test for presence, not truthiness: a group object may legitimately be
    # falsy (e.g. an empty container or one defining __bool__/__len__) and
    # must not be replaced by a fresh group in that case.
    selected_group: Optional[GroupT] = groups.get(group_key)
    if selected_group is None:
      selected_group = group_fn(group_key)
      groups[group_key] = selected_group
    selected_group.append(group_item)
  if sort_key_fn:
    # sort keys as well for more predictable behavior
    return dict(sorted(groups.items(), key=sort_key_fn))
  return groups
def close_matches_message(choice: str,
                          choices: Iterable[str],
                          name: str = "") -> tuple[str, str | None]:
  """Build a hint message for an invalid choice.

  Args:
    choice: the rejected input.
    choices: the valid options; must not be empty.
    name: optional label for the kind of value; when given the message
      starts with "Invalid <name>: <choice>.".

  Returns:
    A (message, alternative) tuple. The alternative is the single close
    match if difflib finds exactly one, otherwise None. With several close
    matches the message lists them; with none it lists all choices.

  Raises:
    ValueError: if choices is empty.
  """
  choices = tuple(choices)
  if not choices:
    raise ValueError("Expected non-empty choices.")
  similar_choices = difflib.get_close_matches(choice, choices)
  prefix: str = f"Invalid {name}: {repr(choice)}." if name else ""
  if not similar_choices:
    # Nothing looks similar: fall back to listing every valid option.
    return prefix + f" Choices are {','.join(choices)}", None
  if len(similar_choices) == 1:
    alternative = similar_choices[0]
    return prefix + f" Did you mean {repr(alternative)}?", alternative
  # Several candidates: too ambiguous to suggest a single alternative.
  return prefix + f" Did you mean one of {', '.join(similar_choices)}?", None